From 16ec8af1533d41336c4952c40fbd01c919f31852 Mon Sep 17 00:00:00 2001
From: s5u13b
Date: Mon, 13 Jan 2025 06:50:44 +0000
Subject: [PATCH 01/59] Refine logger output text

---
 llumnix/manager.py                     | 1 +
 requirements/requirements_bladellm.txt | 1 +
 2 files changed, 2 insertions(+)

diff --git a/llumnix/manager.py b/llumnix/manager.py
index 01675ec0..d2778495 100644
--- a/llumnix/manager.py
+++ b/llumnix/manager.py
@@ -45,6 +45,7 @@
                                WATCH_DEPLOYMENT_INTERVAL, WATCH_DEPLOYMENT_INTERVAL_PENDING_INSTANCE)
 from llumnix.launcher import Launcher
 
+
 logger = init_logger(__name__)
 
 # TODO(s5u13b): Handle exception of ray operations.
diff --git a/requirements/requirements_bladellm.txt b/requirements/requirements_bladellm.txt
index ecaa9301..a55b6c75 100644
--- a/requirements/requirements_bladellm.txt
+++ b/requirements/requirements_bladellm.txt
@@ -5,3 +5,4 @@ pandas
 matplotlib
 pyyaml
 yacs
+loguru

From e99e51900799ca46ba1db595ffde0a701cf1ceae Mon Sep 17 00:00:00 2001
From: s5u13b
Date: Mon, 13 Jan 2025 09:12:17 +0000
Subject: [PATCH 02/59] Customize prefix for actor logs

---
 llumnix/manager.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/llumnix/manager.py b/llumnix/manager.py
index d2778495..0c3a126e 100644
--- a/llumnix/manager.py
+++ b/llumnix/manager.py
@@ -146,6 +146,10 @@ def __init__(self,
         if self.manager_args.enable_pd_disagg:
             asyncio.create_task(self._check_pd_deployment_states_loop(CHECK_DEPLOYMENT_STATES_INTERVAL))
 
+    def __repr__(self):
+        # Customizing prefixes for Actor logs.
+        return f"{self.__class__.__name__}(node_id={self.node_id[:5]})"
+
     async def generate(self, request_id: str, server_info: ServerInfo, *args, **kwargs,) -> None:
         while self.num_instances == 0:
             logger.warning("No instance available now, sleep {}s, "

From 11377bf5c84e643ee3b9b8da1ef4a1f1440687bc Mon Sep 17 00:00:00 2001
From: s5u13b
Date: Mon, 13 Jan 2025 12:11:59 +0000
Subject: [PATCH 03/59] Upgrade logger to vLLM v0.6.6.post1

---
 requirements/requirements_bladellm.txt |   1 -
 tests/unit_test/test_logger.py         | 234 +++++++++++++++++++++++++
 2 files changed, 234 insertions(+), 1 deletion(-)
 create mode 100644 tests/unit_test/test_logger.py

diff --git a/requirements/requirements_bladellm.txt b/requirements/requirements_bladellm.txt
index a55b6c75..ecaa9301 100644
--- a/requirements/requirements_bladellm.txt
+++ b/requirements/requirements_bladellm.txt
@@ -5,4 +5,3 @@ pandas
 matplotlib
 pyyaml
 yacs
-loguru
diff --git a/tests/unit_test/test_logger.py b/tests/unit_test/test_logger.py
new file mode 100644
index 00000000..c967d90c
--- /dev/null
+++ b/tests/unit_test/test_logger.py
@@ -0,0 +1,234 @@
+# Copyright (c) 2024, Alibaba Group;
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Adapted from vLLM(v0.6.6.post1):
+# https://github.com/vllm-project/vllm/blob/cd8249903f189c5f06424e67dbc6512ca494a046/tests/test_logger.py
+
+import json
+import logging
+import os
+import sys
+import tempfile
+from json.decoder import JSONDecodeError
+from tempfile import NamedTemporaryFile
+from typing import Any
+from unittest.mock import patch
+from uuid import uuid4
+
+import pytest
+
+from llumnix.logger import (_DATE_FORMAT, _FORMAT, _configure_llumnix_root_logger,
+                            enable_trace_function_call, init_logger)
+from llumnix.logging_utils import NewLineFormatter
+
+
+def f1(x):
+    return f2(x)
+
+
+def f2(x):
+    return x
+
+
+def test_trace_function_call():
+    # pylint: disable=unused-variable
+    fd, path = tempfile.mkstemp()
+    cur_dir = os.path.dirname(__file__)
+    enable_trace_function_call(path, cur_dir)
+    f1(1)
+    # pylint: disable=unspecified-encoding
+    with open(path) as f:
+        content = f.read()
+
+    assert "f1" in content
+    assert "f2" in content
+    sys.settrace(None)
+    os.remove(path)
+
+
+def test_default_llumnix_root_logger_configuration():
+    """This test presumes that LLUMNIX_CONFIGURE_LOGGING (default: True) and
+    LLUMNIX_LOGGING_CONFIG_PATH (default: None) are not configured and default
+    behavior is activated."""
+    logger = logging.getLogger("llumnix")
+    assert logger.level == logging.DEBUG
+    assert not logger.propagate
+
+    handler = logger.handlers[0]
+    assert isinstance(handler, logging.StreamHandler)
+    assert handler.stream == sys.stdout
+    # we use DEBUG level for testing by default
+    # assert handler.level == logging.INFO
+
+    formatter = handler.formatter
+    assert formatter is not None
+    assert isinstance(formatter, NewLineFormatter)
+    assert formatter._fmt == _FORMAT
+    assert formatter.datefmt == _DATE_FORMAT
+
+
+@patch("llumnix.logger.LLUMNIX_CONFIGURE_LOGGING", 1)
+@patch("llumnix.logger.LLUMNIX_LOGGING_CONFIG_PATH", None)
+def test_descendent_loggers_depend_on_and_propagate_logs_to_root_logger():
+    """This test presumes that LLUMNIX_CONFIGURE_LOGGING (default: True) and
+    LLUMNIX_LOGGING_CONFIG_PATH (default: None) are not configured and default
+    behavior is activated."""
+    root_logger = logging.getLogger("llumnix")
+    root_handler = root_logger.handlers[0]
+
+    unique_name = f"llumnix.{uuid4()}"
+    logger = init_logger(unique_name)
+    assert logger.name == unique_name
+    assert logger.level == logging.NOTSET
+    assert not logger.handlers
+    assert logger.propagate
+
+    message = "Hello, world!"
+    with patch.object(root_handler, "emit") as root_handle_mock:
+        logger.info(message)
+
+    root_handle_mock.assert_called_once()
+    _, call_args, _ = root_handle_mock.mock_calls[0]
+    log_record = call_args[0]
+    assert unique_name == log_record.name
+    assert message == log_record.msg
+    assert message == log_record.msg
+    assert log_record.levelno == logging.INFO
+
+
+@patch("llumnix.logger.LLUMNIX_CONFIGURE_LOGGING", 0)
+@patch("llumnix.logger.LLUMNIX_LOGGING_CONFIG_PATH", None)
+def test_logger_configuring_can_be_disabled():
+    """This test calls _configure_llumnix_root_logger again to test custom logging
+    config behavior, however mocks are used to ensure no changes in behavior or
+    configuration occur."""
+
+    with patch("llumnix.logger.dictConfig") as dict_config_mock:
+        _configure_llumnix_root_logger()
+        dict_config_mock.assert_not_called()
+
+
+@patch("llumnix.logger.LLUMNIX_CONFIGURE_LOGGING", 1)
+@patch(
+    "llumnix.logger.LLUMNIX_LOGGING_CONFIG_PATH",
+    "/if/there/is/a/file/here/then/you/did/this/to/yourself.json",
+)
+def test_an_error_is_raised_when_custom_logging_config_file_does_not_exist():
+    """This test calls _configure_llumnix_root_logger again to test custom logging
+    config behavior, however it fails before any change in behavior or
+    configuration occurs."""
+    with pytest.raises(RuntimeError) as ex_info:
+        _configure_llumnix_root_logger()
+    assert ex_info.type == RuntimeError  # noqa: E721
+    assert "File does not exist" in str(ex_info)
+
+
+@patch("llumnix.logger.LLUMNIX_CONFIGURE_LOGGING", 1)
+def test_an_error_is_raised_when_custom_logging_config_is_invalid_json():
+    """This test calls _configure_llumnix_root_logger again to test custom logging
+    config behavior, however it fails before any change in behavior or
+    configuration occurs."""
+    with NamedTemporaryFile(encoding="utf-8", mode="w") as logging_config_file:
+        logging_config_file.write("---\nloggers: []\nversion: 1")
+        logging_config_file.flush()
+        with patch("llumnix.logger.LLUMNIX_LOGGING_CONFIG_PATH",
+                   logging_config_file.name):
+            with pytest.raises(JSONDecodeError) as ex_info:
+                _configure_llumnix_root_logger()
+            assert ex_info.type == JSONDecodeError
+            assert "Expecting value" in str(ex_info)
+
+
+@patch("llumnix.logger.LLUMNIX_CONFIGURE_LOGGING", 1)
+@pytest.mark.parametrize("unexpected_config", (
+    "Invalid string",
+    [{
+        "version": 1,
+        "loggers": []
+    }],
+    0,
+))
+def test_an_error_is_raised_when_custom_logging_config_is_unexpected_json(
+        unexpected_config: Any):
+    """This test calls _configure_llumnix_root_logger again to test custom logging
+    config behavior, however it fails before any change in behavior or
+    configuration occurs."""
+    with NamedTemporaryFile(encoding="utf-8", mode="w") as logging_config_file:
+        logging_config_file.write(json.dumps(unexpected_config))
+        logging_config_file.flush()
+        with patch("llumnix.logger.LLUMNIX_LOGGING_CONFIG_PATH",
+                   logging_config_file.name):
+            with pytest.raises(ValueError) as ex_info:
+                _configure_llumnix_root_logger()
+            assert ex_info.type == ValueError  # noqa: E721
+            assert "Invalid logging config. Expected Dict, got" in str(ex_info)
+
+
+@patch("llumnix.logger.LLUMNIX_CONFIGURE_LOGGING", 1)
+def test_custom_logging_config_is_parsed_and_used_when_provided():
+    """This test calls _configure_llumnix_root_logger again to test custom logging
+    config behavior, however mocks are used to ensure no changes in behavior or
+    configuration occur."""
+    valid_logging_config = {
+        "loggers": {
+            "llumnix.test_logger.logger": {
+                "handlers": [],
+                "propagate": False,
+            }
+        },
+        "version": 1
+    }
+    with NamedTemporaryFile(encoding="utf-8", mode="w") as logging_config_file:
+        logging_config_file.write(json.dumps(valid_logging_config))
+        logging_config_file.flush()
+        with patch("llumnix.logger.LLUMNIX_LOGGING_CONFIG_PATH",
+                   logging_config_file.name), patch(
+                       "llumnix.logger.dictConfig") as dict_config_mock:
+            _configure_llumnix_root_logger()
+            dict_config_mock.assert_called_with(valid_logging_config)
+
+
+@patch("llumnix.logger.LLUMNIX_CONFIGURE_LOGGING", 0)
+def test_custom_logging_config_causes_an_error_if_configure_logging_is_off():
+    """This test calls _configure_llumnix_root_logger again to test custom logging
+    config behavior, however mocks are used to ensure no changes in behavior or
+    configuration occur."""
+    valid_logging_config = {
+        "loggers": {
+            "llumnix.test_logger.logger": {
+                "handlers": [],
+            }
+        },
+        "version": 1
+    }
+    with NamedTemporaryFile(encoding="utf-8", mode="w") as logging_config_file:
+        logging_config_file.write(json.dumps(valid_logging_config))
+        logging_config_file.flush()
+        with patch("llumnix.logger.LLUMNIX_LOGGING_CONFIG_PATH",
+                   logging_config_file.name):
+            with pytest.raises(RuntimeError) as ex_info:
+                _configure_llumnix_root_logger()
+            assert ex_info.type is RuntimeError
+            expected_message_snippet = (
+                "LLUMNIX_CONFIGURE_LOGGING evaluated to false, but "
+                "LLUMNIX_LOGGING_CONFIG_PATH was given.")
+            assert expected_message_snippet in str(ex_info)
+
+    # Remember! The root logger is assumed to have been configured as
+    # though LLUMNIX_CONFIGURE_LOGGING=1 and LLUMNIX_LOGGING_CONFIG_PATH=None.
+    root_logger = logging.getLogger("llumnix")
+    other_logger_name = f"llumnix.test_logger.{uuid4()}"
+    other_logger = init_logger(other_logger_name)
+    assert other_logger.handlers != root_logger.handlers
+    assert other_logger.level != root_logger.level
+    assert other_logger.propagate

From bc2b83b2a30d5068c91e29cdfad6727bf11d8204 Mon Sep 17 00:00:00 2001
From: s5u13b
Date: Tue, 14 Jan 2025 02:18:35 +0000
Subject: [PATCH 04/59] Reorganize logger

---
 llumnix/manager.py             |   2 +-
 tests/unit_test/test_logger.py | 234 ---------------------------------
 2 files changed, 1 insertion(+), 235 deletions(-)
 delete mode 100644 tests/unit_test/test_logger.py

diff --git a/llumnix/manager.py b/llumnix/manager.py
index 0c3a126e..0ca84251 100644
--- a/llumnix/manager.py
+++ b/llumnix/manager.py
@@ -148,7 +148,7 @@ def __init__(self,
     def __repr__(self):
         # Customizing prefixes for Actor logs.
-        return f"{self.__class__.__name__}(node_id={self.node_id[:5]})"
+        return f"{self.__class__.__name__}(nid={self.node_id[:5]})"
 
     async def generate(self, request_id: str, server_info: ServerInfo, *args, **kwargs,) -> None:
         while self.num_instances == 0:
             logger.warning("No instance available now, sleep {}s, "
diff --git a/tests/unit_test/test_logger.py b/tests/unit_test/test_logger.py
deleted file mode 100644
index c967d90c..00000000
--- a/tests/unit_test/test_logger.py
+++ /dev/null
@@ -1,234 +0,0 @@
-# Copyright (c) 2024, Alibaba Group;
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-
-# http://www.apache.org/licenses/LICENSE-2.0
-
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Adapted from vLLM(v0.6.6.post1):
-# https://github.com/vllm-project/vllm/blob/cd8249903f189c5f06424e67dbc6512ca494a046/tests/test_logger.py
-
-import json
-import logging
-import os
-import sys
-import tempfile
-from json.decoder import JSONDecodeError
-from tempfile import NamedTemporaryFile
-from typing import Any
-from unittest.mock import patch
-from uuid import uuid4
-
-import pytest
-
-from llumnix.logger import (_DATE_FORMAT, _FORMAT, _configure_llumnix_root_logger,
-                            enable_trace_function_call, init_logger)
-from llumnix.logging_utils import NewLineFormatter
-
-
-def f1(x):
-    return f2(x)
-
-
-def f2(x):
-    return x
-
-
-def test_trace_function_call():
-    # pylint: disable=unused-variable
-    fd, path = tempfile.mkstemp()
-    cur_dir = os.path.dirname(__file__)
-    enable_trace_function_call(path, cur_dir)
-    f1(1)
-    # pylint: disable=unspecified-encoding
-    with open(path) as f:
-        content = f.read()
-
-    assert "f1" in content
-    assert "f2" in content
-    sys.settrace(None)
-    os.remove(path)
-
-
-def test_default_llumnix_root_logger_configuration():
-    """This test presumes that LLUMNIX_CONFIGURE_LOGGING (default: True) and
-    LLUMNIX_LOGGING_CONFIG_PATH (default: None) are not configured and default
-    behavior is activated."""
-    logger = logging.getLogger("llumnix")
-    assert logger.level == logging.DEBUG
-    assert not logger.propagate
-
-    handler = logger.handlers[0]
-    assert isinstance(handler, logging.StreamHandler)
-    assert handler.stream == sys.stdout
-    # we use DEBUG level for testing by default
-    # assert handler.level == logging.INFO
-
-    formatter = handler.formatter
-    assert formatter is not None
-    assert isinstance(formatter, NewLineFormatter)
-    assert formatter._fmt == _FORMAT
-    assert formatter.datefmt == _DATE_FORMAT
-
-
-@patch("llumnix.logger.LLUMNIX_CONFIGURE_LOGGING", 1)
-@patch("llumnix.logger.LLUMNIX_LOGGING_CONFIG_PATH", None)
-def test_descendent_loggers_depend_on_and_propagate_logs_to_root_logger():
-    """This test presumes that LLUMNIX_CONFIGURE_LOGGING (default: True) and
-    LLUMNIX_LOGGING_CONFIG_PATH (default: None) are not configured and default
-    behavior is activated."""
-    root_logger = logging.getLogger("llumnix")
-    root_handler = root_logger.handlers[0]
-
-    unique_name = f"llumnix.{uuid4()}"
-    logger = init_logger(unique_name)
-    assert logger.name == unique_name
-    assert logger.level == logging.NOTSET
-    assert not logger.handlers
-    assert logger.propagate
-
-    message = "Hello, world!"
-    with patch.object(root_handler, "emit") as root_handle_mock:
-        logger.info(message)
-
-    root_handle_mock.assert_called_once()
-    _, call_args, _ = root_handle_mock.mock_calls[0]
-    log_record = call_args[0]
-    assert unique_name == log_record.name
-    assert message == log_record.msg
-    assert message == log_record.msg
-    assert log_record.levelno == logging.INFO
-
-
-@patch("llumnix.logger.LLUMNIX_CONFIGURE_LOGGING", 0)
-@patch("llumnix.logger.LLUMNIX_LOGGING_CONFIG_PATH", None)
-def test_logger_configuring_can_be_disabled():
-    """This test calls _configure_llumnix_root_logger again to test custom logging
-    config behavior, however mocks are used to ensure no changes in behavior or
-    configuration occur."""
-
-    with patch("llumnix.logger.dictConfig") as dict_config_mock:
-        _configure_llumnix_root_logger()
-        dict_config_mock.assert_not_called()
-
-
-@patch("llumnix.logger.LLUMNIX_CONFIGURE_LOGGING", 1)
-@patch(
-    "llumnix.logger.LLUMNIX_LOGGING_CONFIG_PATH",
-    "/if/there/is/a/file/here/then/you/did/this/to/yourself.json",
-)
-def test_an_error_is_raised_when_custom_logging_config_file_does_not_exist():
-    """This test calls _configure_llumnix_root_logger again to test custom logging
-    config behavior, however it fails before any change in behavior or
-    configuration occurs."""
-    with pytest.raises(RuntimeError) as ex_info:
-        _configure_llumnix_root_logger()
-    assert ex_info.type == RuntimeError  # noqa: E721
-    assert "File does not exist" in str(ex_info)
-
-
-@patch("llumnix.logger.LLUMNIX_CONFIGURE_LOGGING", 1)
-def test_an_error_is_raised_when_custom_logging_config_is_invalid_json():
-    """This test calls _configure_llumnix_root_logger again to test custom logging
-    config behavior, however it fails before any change in behavior or
-    configuration occurs."""
-    with NamedTemporaryFile(encoding="utf-8", mode="w") as logging_config_file:
-        logging_config_file.write("---\nloggers: []\nversion: 1")
-        logging_config_file.flush()
-        with patch("llumnix.logger.LLUMNIX_LOGGING_CONFIG_PATH",
-                   logging_config_file.name):
-            with pytest.raises(JSONDecodeError) as ex_info:
-                _configure_llumnix_root_logger()
-            assert ex_info.type == JSONDecodeError
-            assert "Expecting value" in str(ex_info)
-
-
-@patch("llumnix.logger.LLUMNIX_CONFIGURE_LOGGING", 1)
-@pytest.mark.parametrize("unexpected_config", (
-    "Invalid string",
-    [{
-        "version": 1,
-        "loggers": []
-    }],
-    0,
-))
-def test_an_error_is_raised_when_custom_logging_config_is_unexpected_json(
-        unexpected_config: Any):
-    """This test calls _configure_llumnix_root_logger again to test custom logging
-    config behavior, however it fails before any change in behavior or
-    configuration occurs."""
-    with NamedTemporaryFile(encoding="utf-8", mode="w") as logging_config_file:
-        logging_config_file.write(json.dumps(unexpected_config))
-        logging_config_file.flush()
-        with patch("llumnix.logger.LLUMNIX_LOGGING_CONFIG_PATH",
-                   logging_config_file.name):
-            with pytest.raises(ValueError) as ex_info:
-                _configure_llumnix_root_logger()
-            assert ex_info.type == ValueError  # noqa: E721
-            assert "Invalid logging config. Expected Dict, got" in str(ex_info)
-
-
-@patch("llumnix.logger.LLUMNIX_CONFIGURE_LOGGING", 1)
-def test_custom_logging_config_is_parsed_and_used_when_provided():
-    """This test calls _configure_llumnix_root_logger again to test custom logging
-    config behavior, however mocks are used to ensure no changes in behavior or
-    configuration occur."""
-    valid_logging_config = {
-        "loggers": {
-            "llumnix.test_logger.logger": {
-                "handlers": [],
-                "propagate": False,
-            }
-        },
-        "version": 1
-    }
-    with NamedTemporaryFile(encoding="utf-8", mode="w") as logging_config_file:
-        logging_config_file.write(json.dumps(valid_logging_config))
-        logging_config_file.flush()
-        with patch("llumnix.logger.LLUMNIX_LOGGING_CONFIG_PATH",
-                   logging_config_file.name), patch(
-                       "llumnix.logger.dictConfig") as dict_config_mock:
-            _configure_llumnix_root_logger()
-            dict_config_mock.assert_called_with(valid_logging_config)
-
-
-@patch("llumnix.logger.LLUMNIX_CONFIGURE_LOGGING", 0)
-def test_custom_logging_config_causes_an_error_if_configure_logging_is_off():
-    """This test calls _configure_llumnix_root_logger again to test custom logging
-    config behavior, however mocks are used to ensure no changes in behavior or
-    configuration occur."""
-    valid_logging_config = {
-        "loggers": {
-            "llumnix.test_logger.logger": {
-                "handlers": [],
-            }
-        },
-        "version": 1
-    }
-    with NamedTemporaryFile(encoding="utf-8", mode="w") as logging_config_file:
-        logging_config_file.write(json.dumps(valid_logging_config))
-        logging_config_file.flush()
-        with patch("llumnix.logger.LLUMNIX_LOGGING_CONFIG_PATH",
-                   logging_config_file.name):
-            with pytest.raises(RuntimeError) as ex_info:
-                _configure_llumnix_root_logger()
-            assert ex_info.type is RuntimeError
-            expected_message_snippet = (
-                "LLUMNIX_CONFIGURE_LOGGING evaluated to false, but "
-                "LLUMNIX_LOGGING_CONFIG_PATH was given.")
-            assert expected_message_snippet in str(ex_info)
-
-    # Remember! The root logger is assumed to have been configured as
-    # though LLUMNIX_CONFIGURE_LOGGING=1 and LLUMNIX_LOGGING_CONFIG_PATH=None.
-    root_logger = logging.getLogger("llumnix")
-    other_logger_name = f"llumnix.test_logger.{uuid4()}"
-    other_logger = init_logger(other_logger_name)
-    assert other_logger.handlers != root_logger.handlers
-    assert other_logger.level != root_logger.level
-    assert other_logger.propagate

From 8b1533defde0ffbb683d652b076130a2e28470ee Mon Sep 17 00:00:00 2001
From: s5u13b
Date: Tue, 14 Jan 2025 02:54:34 +0000
Subject: [PATCH 05/59] Remove date format

---
 tests/unit_test/logging/test_logger.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/unit_test/logging/test_logger.py b/tests/unit_test/logging/test_logger.py
index c6fec057..6f99f1ec 100644
--- a/tests/unit_test/logging/test_logger.py
+++ b/tests/unit_test/logging/test_logger.py
@@ -48,6 +48,7 @@ def test_default_llumnix_root_logger_configuration():
     assert isinstance(formatter, NewLineFormatter)
     assert formatter._fmt == _FORMAT
 
+
 @patch("llumnix.logging.logger.LLUMNIX_CONFIGURE_LOGGING", 1)
 @patch("llumnix.logging.logger.LLUMNIX_LOGGING_CONFIG_PATH", None)
 def test_descendent_loggers_depend_on_and_propagate_logs_to_root_logger():

From 4ef0c7d2e2e2b76b90ad2862c4d8529bff6781a4 Mon Sep 17 00:00:00 2001
From: s5u13b
Date: Tue, 14 Jan 2025 03:37:17 +0000
Subject: [PATCH 06/59] Add constants module

---
 llumnix/entrypoints/setup.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/llumnix/entrypoints/setup.py b/llumnix/entrypoints/setup.py
index ea504275..42a2f32c 100644
--- a/llumnix/entrypoints/setup.py
+++ b/llumnix/entrypoints/setup.py
@@ -35,7 +35,6 @@
 
 logger = init_logger(__name__)
 
-
 def launch_ray_cluster(port: int) -> subprocess.CompletedProcess:
     head_node_ip = os.getenv('HEAD_NODE_IP')
     node_ip_address = get_ip_address()

From df92ab817a40c851b4035c2ec6aabc36791f70b4 Mon Sep 17 00:00:00 2001
From: s5u13b
Date: Tue, 14 Jan 2025 06:30:46 +0000
Subject: [PATCH 07/59] Log ray id for logging

---
 llumnix/entrypoints/setup.py | 1 +
 llumnix/manager.py           | 4 ----
 2 files changed, 1 insertion(+), 4 deletions(-)

diff --git a/llumnix/entrypoints/setup.py b/llumnix/entrypoints/setup.py
index 42a2f32c..ea504275 100644
--- a/llumnix/entrypoints/setup.py
+++ b/llumnix/entrypoints/setup.py
@@ -35,6 +35,7 @@
 
 logger = init_logger(__name__)
 
+
 def launch_ray_cluster(port: int) -> subprocess.CompletedProcess:
     head_node_ip = os.getenv('HEAD_NODE_IP')
     node_ip_address = get_ip_address()
diff --git a/llumnix/manager.py b/llumnix/manager.py
index 0ca84251..d2778495 100644
--- a/llumnix/manager.py
+++ b/llumnix/manager.py
@@ -146,10 +146,6 @@ def __init__(self,
         if self.manager_args.enable_pd_disagg:
             asyncio.create_task(self._check_pd_deployment_states_loop(CHECK_DEPLOYMENT_STATES_INTERVAL))
 
-    def __repr__(self):
-        # Customizing prefixes for Actor logs.
-        return f"{self.__class__.__name__}(nid={self.node_id[:5]})"
-
     async def generate(self, request_id: str, server_info: ServerInfo, *args, **kwargs,) -> None:
         while self.num_instances == 0:
             logger.warning("No instance available now, sleep {}s, "

From 24383cd18bc111642657a2e5d59abf96a846cf0a Mon Sep 17 00:00:00 2001
From: s5u13b
Date: Tue, 14 Jan 2025 09:59:24 +0000
Subject: [PATCH 08/59] Refine logging handlers configuration

---
 llumnix/envs.py                        |  1 -
 llumnix/logging/logger.py              | 26 ++++++++++++++++++++++++++
 tests/unit_test/logging/test_logger.py |  2 +-
 3 files changed, 27 insertions(+), 2 deletions(-)

diff --git a/llumnix/envs.py b/llumnix/envs.py
index ca7aac1c..6a8d2d23 100644
--- a/llumnix/envs.py
+++ b/llumnix/envs.py
@@ -44,7 +44,6 @@
     # if set, llumnix will routing all logs to stream
     "LLUMNIX_LOG_STREAM":
     lambda: os.getenv("LLUMNIX_LOG_STREAM", "1"),
-
    # if set, llumnix will routing all node logs to this path
    "LLUMNIX_LOG_NODE_PATH":
    lambda: os.getenv("LLUMNIX_LOG_NODE_PATH", ""),
diff --git a/llumnix/logging/logger.py b/llumnix/logging/logger.py
index a3491858..5e8bdb52 100644
--- a/llumnix/logging/logger.py
+++ b/llumnix/logging/logger.py
@@ -131,7 +131,33 @@ def _configure_llumnix_root_logger() -> None:
                 "implies LLUMNIX_CONFIGURE_LOGGING. Please enable "
                 "LLUMNIX_CONFIGURE_LOGGING or unset LLUMNIX_LOGGING_CONFIG_PATH.")
 
+    print(f"LLUMNIX_CONFIGURE_LOGGING: {LLUMNIX_CONFIGURE_LOGGING}")
+    print(f"LLUMNIX_LOG_STREAM: {LLUMNIX_LOG_STREAM}")
+    print(f"LLUMNIX_LOG_NODE_PATH: {LLUMNIX_LOG_NODE_PATH}")
+
+    if LLUMNIX_CONFIGURE_LOGGING:
+        if LLUMNIX_LOG_STREAM:
+            print(f"LLUMNIX_LOG_STREAM: {LLUMNIX_LOG_STREAM}")
+            DEFAULT_LOGGING_CONFIG["handlers"]["stream"] = {
+                "class": "logging.StreamHandler",
+                "formatter": "llumnix",
+                "level": LLUMNIX_LOGGING_LEVEL,
+                "stream": "ext://sys.stdout",
+            }
+            DEFAULT_LOGGING_CONFIG["loggers"]["llumnix"]["handlers"].append("stream")
+
+        if LLUMNIX_LOG_NODE_PATH:
+            print(f"LLUMNIX_LOG_NODE_PATH: {LLUMNIX_LOG_NODE_PATH}")
+            DEFAULT_LOGGING_CONFIG["handlers"]["file"] = {
+                "class": "llumnix.logging.NodeFileHandler",
+                "formatter": "llumnix",
+                "level": LLUMNIX_LOGGING_LEVEL,
+                "base_path": LLUMNIX_LOG_NODE_PATH,
+            }
+            DEFAULT_LOGGING_CONFIG["loggers"]["llumnix"]["handlers"].append("file")
+
+    print(f"DEFAULT_LOGGING_CONFIG: {DEFAULT_LOGGING_CONFIG}")
+
     logging_config = DEFAULT_LOGGING_CONFIG
 
     if LLUMNIX_LOGGING_CONFIG_PATH:
diff --git a/tests/unit_test/logging/test_logger.py b/tests/unit_test/logging/test_logger.py
index 6f99f1ec..cf1398a8 100644
--- a/tests/unit_test/logging/test_logger.py
+++ b/tests/unit_test/logging/test_logger.py
@@ -24,6 +24,7 @@
 from uuid import uuid4
 
 import pytest
+from unittest.mock import MagicMock
 
 from llumnix.logging.logger import _FORMAT, _configure_llumnix_root_logger, init_logger
 from llumnix.logging import NewLineFormatter
@@ -49,7 +50,6 @@ def test_default_llumnix_root_logger_configuration():
     assert isinstance(formatter, NewLineFormatter)
     assert formatter._fmt == _FORMAT
 
-
 @patch("llumnix.logging.logger.LLUMNIX_CONFIGURE_LOGGING", 1)
 @patch("llumnix.logging.logger.LLUMNIX_LOGGING_CONFIG_PATH", None)
 def test_descendent_loggers_depend_on_and_propagate_logs_to_root_logger():

From a4df07c80c43a15d00dc1c5f943fad96a4cb73da Mon Sep 17 00:00:00 2001
From: s5u13b
Date: Tue, 14 Jan 2025 10:01:47 +0000
Subject: [PATCH 09/59] Fix lint

---
 llumnix/logging/logger.py              | 2 +-
 tests/unit_test/logging/test_logger.py | 1 -
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/llumnix/logging/logger.py b/llumnix/logging/logger.py
index 5e8bdb52..becb6c61 100644
--- a/llumnix/logging/logger.py
+++ b/llumnix/logging/logger.py
@@ -155,7 +155,7 @@ def _configure_llumnix_root_logger() -> None:
                 "base_path": LLUMNIX_LOG_NODE_PATH,
             }
             DEFAULT_LOGGING_CONFIG["loggers"]["llumnix"]["handlers"].append("file")
-
+
     print(f"DEFAULT_LOGGING_CONFIG: {DEFAULT_LOGGING_CONFIG}")
 
     logging_config = DEFAULT_LOGGING_CONFIG
diff --git a/tests/unit_test/logging/test_logger.py b/tests/unit_test/logging/test_logger.py
index cf1398a8..c6fec057 100644
--- a/tests/unit_test/logging/test_logger.py
+++ b/tests/unit_test/logging/test_logger.py
@@ -24,7 +24,6 @@
 from uuid import uuid4
 
 import pytest
-from unittest.mock import MagicMock
 
 from llumnix.logging.logger import _FORMAT, _configure_llumnix_root_logger, init_logger
 from llumnix.logging import NewLineFormatter

From f3c427b3ba634135db5d93c34e81ff4c0515879d Mon Sep 17 00:00:00 2001
From: s5u13b
Date: Tue, 14 Jan 2025 10:32:24 +0000
Subject: [PATCH 10/59] Refine logger

---
 llumnix/logging/logger.py | 26 --------------------------
 1 file changed, 26 deletions(-)

diff --git a/llumnix/logging/logger.py b/llumnix/logging/logger.py
index becb6c61..a3491858 100644
--- a/llumnix/logging/logger.py
+++ b/llumnix/logging/logger.py
@@ -131,33 +131,7 @@ def _configure_llumnix_root_logger() -> None:
                 "implies LLUMNIX_CONFIGURE_LOGGING. Please enable "
                 "LLUMNIX_CONFIGURE_LOGGING or unset LLUMNIX_LOGGING_CONFIG_PATH.")
 
-    print(f"LLUMNIX_CONFIGURE_LOGGING: {LLUMNIX_CONFIGURE_LOGGING}")
-    print(f"LLUMNIX_LOG_STREAM: {LLUMNIX_LOG_STREAM}")
-    print(f"LLUMNIX_LOG_NODE_PATH: {LLUMNIX_LOG_NODE_PATH}")
-
-    if LLUMNIX_CONFIGURE_LOGGING:
-        if LLUMNIX_LOG_STREAM:
-            print(f"LLUMNIX_LOG_STREAM: {LLUMNIX_LOG_STREAM}")
-            DEFAULT_LOGGING_CONFIG["handlers"]["stream"] = {
-                "class": "logging.StreamHandler",
-                "formatter": "llumnix",
-                "level": LLUMNIX_LOGGING_LEVEL,
-                "stream": "ext://sys.stdout",
-            }
-            DEFAULT_LOGGING_CONFIG["loggers"]["llumnix"]["handlers"].append("stream")
-
-        if LLUMNIX_LOG_NODE_PATH:
-            print(f"LLUMNIX_LOG_NODE_PATH: {LLUMNIX_LOG_NODE_PATH}")
-            DEFAULT_LOGGING_CONFIG["handlers"]["file"] = {
-                "class": "llumnix.logging.NodeFileHandler",
-                "formatter": "llumnix",
-                "level": LLUMNIX_LOGGING_LEVEL,
-                "base_path": LLUMNIX_LOG_NODE_PATH,
-            }
-            DEFAULT_LOGGING_CONFIG["loggers"]["llumnix"]["handlers"].append("file")
-
-    print(f"DEFAULT_LOGGING_CONFIG: {DEFAULT_LOGGING_CONFIG}")
-
     logging_config = DEFAULT_LOGGING_CONFIG
 
     if LLUMNIX_LOGGING_CONFIG_PATH:

From fea7fa795e653c5f8de1da04aa427afb0122d0d4 Mon Sep 17 00:00:00 2001
From: s5u13b
Date: Fri, 17 Jan 2025 07:25:21 +0000
Subject: [PATCH 11/59] Optimize constants

---
 llumnix/manager.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/llumnix/manager.py b/llumnix/manager.py
index d2778495..01675ec0 100644
--- a/llumnix/manager.py
+++ b/llumnix/manager.py
@@ -45,7 +45,6 @@
                                WATCH_DEPLOYMENT_INTERVAL, WATCH_DEPLOYMENT_INTERVAL_PENDING_INSTANCE)
 from llumnix.launcher import Launcher
 
-
 logger = init_logger(__name__)
 
 # TODO(s5u13b): Handle exception of ray operations.
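The environment-driven handler wiring that PATCH 08 introduces (and PATCH 10 reverts) follows the standard logging.config.dictConfig pattern. A minimal, self-contained sketch of that pattern, assuming a config shaped like llumnix's DEFAULT_LOGGING_CONFIG — the formatter string and handler names here are illustrative stand-ins, not the actual llumnix values:

import logging
from logging.config import dictConfig

# Illustrative stand-in for llumnix's DEFAULT_LOGGING_CONFIG.
LOGGING_CONFIG = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {
        "llumnix": {"format": "%(levelname)s %(asctime)s [%(name)s] %(message)s"},
    },
    "handlers": {},
    "loggers": {
        "llumnix": {"handlers": [], "level": "DEBUG", "propagate": False},
    },
}

# Conditionally attach a stream handler, mirroring the LLUMNIX_LOG_STREAM branch.
LOGGING_CONFIG["handlers"]["stream"] = {
    "class": "logging.StreamHandler",
    "formatter": "llumnix",
    "level": "DEBUG",
    "stream": "ext://sys.stdout",
}
LOGGING_CONFIG["loggers"]["llumnix"]["handlers"].append("stream")

dictConfig(LOGGING_CONFIG)
logging.getLogger("llumnix").info("stream handler wired via dictConfig")

Because dictConfig replaces the whole configuration on each call, building the handler dict up front and applying it once — as PATCH 08 does — avoids stacking duplicate handlers across repeated configuration calls.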
From 09519bcf85a92219708eefab866cb94e14cb090 Mon Sep 17 00:00:00 2001
From: s5u13b
Date: Wed, 15 Jan 2025 11:09:21 +0000
Subject: [PATCH 12/59] Fix constants

---
 llumnix/backends/vllm/llm_engine.py | 7 ++++++-
 llumnix/constants.py                | 3 +++
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/llumnix/backends/vllm/llm_engine.py b/llumnix/backends/vllm/llm_engine.py
index ce980030..8ebbd4bd 100644
--- a/llumnix/backends/vllm/llm_engine.py
+++ b/llumnix/backends/vllm/llm_engine.py
@@ -41,10 +41,11 @@
 from llumnix.queue.utils import QueueType
 from llumnix.backends.utils import AsyncPutQueueActor
 from llumnix.utils import get_instance_name
+from llumnix.constants import constants
 
 logger = init_logger(__name__)
 
-NO_OUTPUTS_STEP_INTERVAL = 0.01
+NO_OUTPUTS_STEP_INTERVAL = constants.NO_OUTPUTS_STEP_INTERVAL
 
 
 class LlumnixRequestOutputFactory(RequestOutputFactory):
@@ -224,6 +225,10 @@
 
         self.instance_info = instance_info
 
+        for request_output in request_outputs:
+            if hasattr(request_output, 'request_timestamps'):
+                request_output.request_timestamps.engine_put_queue_timestamp = time.time()
+
         if request_outputs:
             self.put_queue_args_queue.put_nowait((request_outputs, server_infos))
diff --git a/llumnix/constants.py b/llumnix/constants.py
index bd174c5e..9b13e74e 100644
--- a/llumnix/constants.py
+++ b/llumnix/constants.py
@@ -37,6 +37,9 @@
 # llumnix/llumlet/llumlet.py
 CHECK_ENGINE_STATE_INTERVAL: float = 1.0
 
+# llumnix/backends/vllm/llm_engine.py
+NO_OUTPUTS_STEP_INTERVAL: float = 0.01
+
 # llumnix/queue/zmq_utils.py
 RPC_GET_DATA_TIMEOUT_MS: int = 5000
 RPC_SOCKET_LIMIT_CUTOFF: int = 2000

From eec7a0dac3d30b44435a90a132288acc87e0b315 Mon Sep 17 00:00:00 2001
From: s5u13b
Date: Thu, 16 Jan 2025 10:35:51 +0000
Subject: [PATCH 13/59] Refactor

---
 benchmark/benchmark_serving.py                | 17 ++---
 llumnix/backends/utils.py                     |  5 +-
 llumnix/backends/vllm/llm_engine.py           | 34 ++++-----
 llumnix/entrypoints/bladellm/client.py        |  4 +-
 llumnix/entrypoints/utils.py                  | 16 -----
 llumnix/entrypoints/vllm/api_server.py        | 13 ++--
 llumnix/entrypoints/vllm/client.py            |  8 +--
 llumnix/llumlet/llumlet.py                    |  4 +-
 llumnix/manager.py                            |  4 +-
 llumnix/metrics/base_metrics.py               |  8 +--
 llumnix/metrics/dumper.py                     |  2 +
 llumnix/metrics/timestamps.py                 | 70 +++++++++++++++
 llumnix/metrics/variable.py                   |  3 +
 llumnix/queue/ray_queue_client.py             | 10 +--
 llumnix/queue/ray_queue_server.py             | 11 +--
 llumnix/queue/zmq_client.py                   | 10 +--
 llumnix/queue/zmq_server.py                   | 11 +--
 llumnix/server_info.py                        | 50 ------------
 tests/e2e_test/utils.py                       |  6 +-
 .../unit_test/entrypoints/vllm/api_server.py  |  3 +-
 20 files changed, 137 insertions(+), 152 deletions(-)
 create mode 100644 llumnix/metrics/timestamps.py

diff --git a/benchmark/benchmark_serving.py b/benchmark/benchmark_serving.py
index d9250a0e..056e9cc7 100644
--- a/benchmark/benchmark_serving.py
+++ b/benchmark/benchmark_serving.py
@@ -372,7 +372,7 @@ def __init__(self):
         self._decode_sum_latencies = []
         self._all_decode_token_latencies = []
         self._inference_latencies = []
-        self._per_token_latencies_breakdown_dict = []
+        self._per_token_latency_breakdown_list = []
 
     def measure(self, f):
         async def measured(*args, **kwargs):
@@ -400,9 +400,10 @@ async def measured(*args, **kwargs):
                 self._all_token_latencies.append(lat_arr)
                 self._decode_sum_latencies.append(decode_sum_latency)
                 self._all_decode_token_latencies.extend(lat_arr[1:,1])
-                if 'per_token_latency_breakdown_dict' in output:
-                    self._inference_latencies.append(np.mean(output['per_token_latency_breakdown_dict']['step_latency_engine']))
-                    self._per_token_latencies_breakdown_dict.append(output['per_token_latency_breakdown_dict'])
+                if 'per_token_latency_breakdown_list' in output:
+                    step_latency = np.mean([request_timestamps['engine_step_latency'] for request_timestamps in output['per_token_latency_breakdown_list']])
+                    self._inference_latencies.append(step_latency)
+                    self._per_token_latency_breakdown_list.append(output['per_token_latency_breakdown_list'])
             return prompt, output
         return measured
 
@@ -494,7 +495,7 @@ async def benchmark(
            m._decode_sum_latencies, \
            m._request_lens, \
            m._all_decode_token_latencies, \
-           m._per_token_latencies_breakdown_dict
+           m._per_token_latency_breakdown_list
 
 def gen_random_response_lens(distribution: str, len_mean, len_range, num_prompts):
     if distribution == 'uniform':
@@ -785,7 +786,7 @@ def main():
         decode_sum_latencies, \
         request_lens, \
         all_decode_token_latencies, \
-        per_token_latencies_breakdown_dict = asyncio.run(benchmark(
+        per_token_latency_breakdown_list = asyncio.run(benchmark(
             backend,
             tokenizer,
             prompts,
@@ -823,8 +824,8 @@ def main():
                             "decode_sum_latencies": decode_sum_latencies,
                             "all_decode_token_latencies": all_decode_token_latencies,
                             "inference_latencies": inference_latencies,
-                            "per_token_latencies_breakdown_dict": per_token_latencies_breakdown_dict,
-                            "throughput": throughput,
+                            "per_token_latency_breakdown_list": per_token_latency_breakdown_list,
+                            "throughput": throughput,
                             "instance_num": avg_instance_num})
             json.dump(results, f)
 
diff --git a/llumnix/backends/utils.py b/llumnix/backends/utils.py
index 9aed9101..997c1043 100644
--- a/llumnix/backends/utils.py
+++ b/llumnix/backends/utils.py
@@ -27,6 +27,7 @@
 from llumnix.logging.logger import init_logger
 from llumnix.utils import get_instance_name
 from llumnix.internal_config import MigrationConfig
+from llumnix.metrics.timestamps import set_timestamp
 
 logger = init_logger(__name__)
 
@@ -55,9 +56,7 @@ async def put_nowait_to_servers(self,
         tasks = []
         for server_id, req_outputs in server_request_outputs.items():
             server_info = server_info_dict[server_id]
-            for req_output in req_outputs:
-                if hasattr(req_output, 'request_timestamps'):
-                    req_output.request_timestamps.engine_actor_put_queue_timestamp = time.time()
+            set_timestamp(req_outputs, 'engine_actor_put_queue_timestamp', time.time())
             tasks.append(asyncio.create_task(self.request_output_queue_client.put_nowait(req_outputs, server_info)))
         rets = await asyncio.gather(*tasks, return_exceptions=True)
         for idx, ret in enumerate(rets):
diff --git a/llumnix/backends/vllm/llm_engine.py b/llumnix/backends/vllm/llm_engine.py
index 8ebbd4bd..49d1375e 100644
--- a/llumnix/backends/vllm/llm_engine.py
+++ b/llumnix/backends/vllm/llm_engine.py
@@ -41,11 +41,12 @@
 from llumnix.queue.utils import QueueType
 from llumnix.backends.utils import AsyncPutQueueActor
 from llumnix.utils import get_instance_name
-from llumnix.constants import constants
+from llumnix import constants
+from llumnix.metrics.timestamps import set_timestamp
 
 logger = init_logger(__name__)
 
 NO_OUTPUTS_STEP_INTERVAL = constants.NO_OUTPUTS_STEP_INTERVAL
 
@@ -169,9 +170,7 @@ def _process_model_outputs(self,
             ctx.output_queue.appendleft((outputs, seq_group_metadata_list, scheduler_outputs, is_async,
                                          is_last_step, is_first_step_output, skip))
 
-        for server_info in server_infos:
-            if hasattr(server_info, 'request_timestamps'):
-                server_info.request_timestamps.engine_process_model_outputs_timestamp_begin = time.time()
+        set_timestamp(server_infos, 'engine_process_model_outputs_timestamp_begin', time.time())
 
         super()._process_model_outputs(ctx, request_id)
 
@@ -196,10 +195,8 @@ def _process_request_outputs(
         request_outputs = list(request_outputs)
         server_infos = list(server_infos)
 
-        for request_output in request_outputs:
-            if hasattr(request_output, 'request_timestamps'):
-                request_output.request_timestamps.engine_step_timestamp_begin = step_begin_time
-                request_output.request_timestamps.engine_step_timestamp_end = time.time()
+        set_timestamp(request_outputs, 'engine_step_timestamp_begin', step_begin_time)
+        set_timestamp(request_outputs, 'engine_step_timestamp_end', time.time())
 
         for request_output in request_outputs:
             if request_output.finished:
@@ -225,16 +222,12 @@
 
         self.instance_info = instance_info
 
-        for request_output in request_outputs:
-            if hasattr(request_output, 'request_timestamps'):
-                request_output.request_timestamps.engine_put_queue_timestamp = time.time()
+        set_timestamp(request_outputs, 'engine_put_queue_timestamp', time.time())
 
         if request_outputs:
             self.put_queue_args_queue.put_nowait((request_outputs, server_infos))
 
-        for request_output in request_outputs:
-            if hasattr(request_output, 'request_timestamps'):
-                request_output.request_timestamps.engine_step_postprocess_timestamp_end = time.time()
+        set_timestamp(request_outputs, 'engine_step_postprocess_timestamp_end', time.time())
 
         return request_outputs, server_infos
 
@@ -257,20 +250,17 @@ def update_instance_info(self, instance_info: InstanceInfo) -> None:
     def add_request(self, request_id: str, server_info: ServerInfo, expected_steps: int, *args, **kwargs):
         super().add_request(request_id, *args, **kwargs)
         seq_group = self.scheduler[0].waiting[-1]
-        if hasattr(server_info, 'request_timestamps'):
-            server_info.request_timestamps.engine_add_request_timestamp = time.time()
+        set_timestamp(server_info, 'engine_add_request_timestamp', time.time())
         self.scheduler[0].waiting[-1] = SequenceGroupLlumnix(request_id, server_info, expected_steps, [seq_group.get_seqs()[0]],
-                                                             seq_group.metrics.arrival_time, seq_group.sampling_params, seq_group.lora_request,
-                                                             seq_group.trace_headers, seq_group.prompt_adapter_request, seq_group.encoder_seq,
-                                                             seq_group.priority)
+                                                             seq_group.metrics.arrival_time, seq_group.sampling_params, seq_group.lora_request,
+                                                             seq_group.trace_headers, seq_group.prompt_adapter_request, seq_group.encoder_seq,
+                                                             seq_group.priority)
 
     def _start_put_queue_loop(self):
         while True:
             args = self.put_queue_args_queue.get()
             request_outputs, server_infos = args
-            for request_output in request_outputs:
-                if hasattr(request_output, 'request_timestamps'):
-                    request_output.request_timestamps.engine_thread_put_queue_timestamp = time.time()
+            set_timestamp(request_outputs, 'engine_thread_put_queue_timestamp', time.time())
             self._put_request_outputs_to_server(request_outputs, server_infos)
 
     def _put_request_outputs_to_server(self, request_outputs: List[RequestOutput], server_infos: List[ServerInfo]) -> None:
diff --git a/llumnix/entrypoints/bladellm/client.py b/llumnix/entrypoints/bladellm/client.py
index a3ea2314..fe95cc21 100644
--- a/llumnix/entrypoints/bladellm/client.py
+++ b/llumnix/entrypoints/bladellm/client.py
@@ -27,7 +27,7 @@
 from blade_llm.protocol import ServerRequest, GenerateStreamResponse
 from blade_llm.service.communications.response import error_resp
 
-from llumnix.server_info import RequestTimestamps
+from llumnix.metrics.timestamps import RequestTimestamps
 from llumnix.entrypoints.utils import EntrypointsContext
 from llumnix.logging.logger import init_logger
 from llumnix.constants import WAIT_MANAGER_INTERVAL
@@ -87,7 +87,7 @@ async def _manager_generate(self, request, request_id: str) -> LLMResponse:
             if self.llumnix_context.log_request_timestamps:
                 # Hack request timestamps in server_info for latency breakdown.
                 server_info_copy.request_timestamps = RequestTimestamps()
-                server_info_copy.request_timestamps.api_server_manager_generate_timestamp = time.time()
+                server_info_copy.request_timestamps.api_server_generate_timestamp = time.time()
             # await to catch exception
             await self.llumnix_context.manager.generate.remote(str(request_id), server_info_copy, server_request=request)
             self.llumnix_context.manager_available = True
diff --git a/llumnix/entrypoints/utils.py b/llumnix/entrypoints/utils.py
index 8db27b9a..936feeb6 100644
--- a/llumnix/entrypoints/utils.py
+++ b/llumnix/entrypoints/utils.py
@@ -74,19 +74,3 @@
             logger.error("Manager is still unavailable after {} times retries.".format(MAX_TASK_RETRIES))
             raise
     return ret
-
-def init_per_token_latency_breakdown_dict() -> Dict[str, int]:
-    per_token_latency_breakdown_dict = {
-        'step_latency_engine': [],
-        'step_postprocess_latency': [],
-        'across_async_put_queue_thread_latency': [],
-        'across_async_put_queue_actor_latency': [],
-        'queue_rpc_latency': [],
-        'background_process_get_queue_latency': [],
-        'generate_benchmark_return_output_latency': []
-    }
-    return per_token_latency_breakdown_dict
-
-def record_per_token_latency_breakdown(per_token_latency_breakdown_dict: Dict[str, int], request_timestamps: "RequestTimestamps"):
-    for key in per_token_latency_breakdown_dict.keys():
-        per_token_latency_breakdown_dict[key].append(getattr(request_timestamps, key))
diff --git a/llumnix/entrypoints/vllm/api_server.py b/llumnix/entrypoints/vllm/api_server.py
index 6648f5e7..e2aa700d 100644
--- a/llumnix/entrypoints/vllm/api_server.py
+++ b/llumnix/entrypoints/vllm/api_server.py
@@ -24,7 +24,6 @@
 
 from llumnix.arg_utils import LlumnixArgumentParser, LaunchArgs
 from llumnix.entrypoints.setup import setup_ray_cluster, setup_llumnix
-from llumnix.entrypoints.utils import init_per_token_latency_breakdown_dict, record_per_token_latency_breakdown
 from llumnix.entrypoints.vllm.arg_utils import add_cli_args, get_args
 from llumnix.entrypoints.vllm.client import LlumnixClientVLLM
 from llumnix.logging.logger import init_logger
@@ -33,6 +32,7 @@
 from llumnix.backends.backend_interface import BackendType
 from llumnix.entrypoints.utils import LaunchMode, is_gpu_available
 from llumnix.constants import SERVER_TIMEOUT_KEEP_ALIVE
+from llumnix.metrics.timestamps import set_timestamp
 
 # Code file with __main__ should set the logger name to inherit the llumnix logger configuration.
 logger = init_logger("llumnix.entrypoints.vllm.api_server")
@@ -126,8 +126,8 @@ async def generate_benchmark(request: Request) -> Response:
     # Non-streaming case
     final_output = None
     per_token_latency = []
-    per_token_latency_breakdown_dict = init_per_token_latency_breakdown_dict()
-    async for request_output in results_generator.generator():
+    per_token_latency_breakdown_list = []
+    async for request_output in results_generator:
         if await request.is_disconnected():
             # Abort the request if the client disconnects.
             await llumnix_client.abort(request_id)
             return Response(status_code=499)
@@ -136,9 +136,9 @@ async def generate_benchmark(request: Request) -> Response:
         now = time.time()
         per_token_latency.append([now, (now - start)*1000])
         start = now
         final_output = request_output
+        set_timestamp(request_output, 'api_server_generate_timestamp_end', now)
         if hasattr(request_output, 'request_timestamps'):
-            request_output.request_timestamps.api_server_generate_benchmark_timestamp_end = now
-            record_per_token_latency_breakdown(per_token_latency_breakdown_dict, request_output.request_timestamps)
+            per_token_latency_breakdown_list.append(request_output.request_timestamps.to_latency_breakdown_dict())
 
     assert final_output is not None
     if llumnix_client.log_requests:
@@ -158,8 +158,9 @@ async def generate_benchmark(request: Request) -> Response:
         'generated_text': generation,
         'num_output_tokens_cf': num_output_tokens,
         'per_token_latency': per_token_latency,
-        'per_token_latency_breakdown_dict': per_token_latency_breakdown_dict
     }
+    if per_token_latency_breakdown_list:
+        ret['per_token_latency_breakdown_list'] = per_token_latency_breakdown_list
     return JSONResponse(ret)
 
diff --git a/llumnix/entrypoints/vllm/client.py b/llumnix/entrypoints/vllm/client.py
index c8fad376..345f985a 100644
--- a/llumnix/entrypoints/vllm/client.py
+++ b/llumnix/entrypoints/vllm/client.py
@@ -10,7 +10,7 @@
 
 from llumnix.logging.logger import init_logger
 from llumnix.entrypoints.utils import EntrypointsContext
-from llumnix.server_info import RequestTimestamps
+from llumnix.metrics.timestamps import RequestTimestamps, set_timestamp
 from llumnix.queue.queue_server_base import QueueServerBase
 from llumnix.server_info import ServerInfo
 from llumnix.manager import Manager
@@ -74,7 +74,7 @@ async def _generate_by_manager(self,
         if self.log_request_timestamps:
             # Hack request timestamps in server_info for latency breakdown.
             server_info.request_timestamps = RequestTimestamps()
-            server_info.request_timestamps.api_server_manager_generate_timestamp = time.time()
+            server_info.request_timestamps.api_server_generate_timestamp = time.time()
         await self.manager.generate.remote(request_id, server_info, prompt, sampling_params, *args, **kwargs)
 
     async def _generate_by_instance(self,
@@ -128,9 +128,7 @@ async def is_ready(self) -> bool:
     async def get_request_outputs_loop(self):
         while True:
             request_outputs = await self.request_output_queue.get()
-            for request_output in request_outputs:
-                if hasattr(request_output, 'request_timestamps'):
-                    request_output.request_timestamps.api_server_background_process_get_queue_timestamp = time.time()
+            set_timestamp(request_outputs, 'api_server_get_queue_timestamp', time.time())
             for request_output in request_outputs:
                 request_id = request_output.request_id
                 # Request could be dispatched twice when manager is dead, the first request will free the request_streams when finished.
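The set_timestamp helper that this patch threads through the server, queue, and engine layers replaces the repeated hasattr loops deleted above. A minimal usage sketch, assuming RequestTimestamps and set_timestamp exactly as added in llumnix/metrics/timestamps.py below; the Output class is a hypothetical stand-in for an engine RequestOutput:

import time
from llumnix.metrics.timestamps import RequestTimestamps, set_timestamp

class Output:  # hypothetical stand-in for a RequestOutput
    def __init__(self):
        self.request_timestamps = RequestTimestamps()

outputs = [Output(), Output()]
# set_timestamp accepts a single object or any iterable; items without a
# request_timestamps attribute (or an unknown attribute name) are skipped.
set_timestamp(outputs, "engine_put_queue_timestamp", time.time())
print(outputs[0].request_timestamps.engine_put_queue_timestamp)

Centralizing the stamping this way keeps every call site to one line and makes the timestamp attribute names checkable against the dataclass fields.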
diff --git a/llumnix/llumlet/llumlet.py b/llumnix/llumlet/llumlet.py
index 185aea64..5aab64dd 100644
--- a/llumnix/llumlet/llumlet.py
+++ b/llumnix/llumlet/llumlet.py
@@ -33,6 +33,7 @@
 from llumnix.arg_utils import InstanceArgs
 from llumnix.utils import get_instance_name
 from llumnix.constants import CHECK_ENGINE_STATE_INTERVAL
+from llumnix.metrics.timestamps import set_timestamp
 
 logger = init_logger(__name__)
 
@@ -213,8 +214,7 @@ def get_all_request_ids(self) -> List[str]:
         return self.backend_engine.get_all_request_ids()
 
     def generate(self, request_id: str, server_info: ServerInfo, expected_steps: int, *args, **kwargs) -> None:
-        if hasattr(server_info, 'request_timestamps'):
-            server_info.request_timestamps.llumlet_generate_timestamp = time.time()
+        set_timestamp(server_info, 'llumlet_generate_timestamp', time.time())
         self.backend_engine.add_request(request_id, server_info, expected_steps, *args, **kwargs)
 
     def abort(self, request_id: Union[str, Iterable[str]]) -> None:
diff --git a/llumnix/manager.py b/llumnix/manager.py
index 01675ec0..84db8bf3 100644
--- a/llumnix/manager.py
+++ b/llumnix/manager.py
@@ -44,6 +44,7 @@
                                WAIT_PLACEMENT_GROUP_TIMEOUT, CHECK_DEPLOYMENT_STATES_INTERVAL,
                                WATCH_DEPLOYMENT_INTERVAL, WATCH_DEPLOYMENT_INTERVAL_PENDING_INSTANCE)
 from llumnix.launcher import Launcher
+from llumnix.metrics.timestamps import set_timestamp
 
 logger = init_logger(__name__)
 
@@ -153,8 +154,7 @@ async def generate(self, request_id: str, server_info: ServerInfo, *args, **kwar
 
         instance_id, request_expected_steps = self.global_scheduler.dispatch()
         try:
-            if hasattr(server_info, 'request_timestamps'):
-                server_info.request_timestamps.manager_generate_timestamp = time.time()
+            set_timestamp(server_info, 'manager_generate_timestamp', time.time())
             await self.instances[instance_id].generate.remote(request_id, server_info, request_expected_steps, *args, **kwargs)
             if self.log_requests:
                 logger.info("manager receive request {}".format(request_id))
diff --git a/llumnix/metrics/base_metrics.py b/llumnix/metrics/base_metrics.py
index 534d6df1..a51e0a95 100644
--- a/llumnix/metrics/base_metrics.py
+++ b/llumnix/metrics/base_metrics.py
@@ -15,6 +15,7 @@
 
 from llumnix.metrics.variable import _REGISTRY, Status
 from llumnix.metrics.dumper import Dumper, DummyDumper
+
 from llumnix.instance_info import InstanceInfo
 
 
@@ -30,8 +31,7 @@ def __init__(self):
         self.num_running_requests = Status("num_running_requests")
         self.num_waiting_requests = Status("num_waiting_requests")
 
-        self.dumper: Dumper = None
-        self._init_dumper()
+        self.dumper: Dumper = self._init_dumper()
 
     def dump(self):
         self.dumper.dump(_REGISTRY.describe_all())
@@ -40,7 +40,8 @@ def to_instance_info(self) -> InstanceInfo:
         return InstanceInfo(**(_REGISTRY.describe_all()))
 
     def _init_dumper(self,):
-        self.dumper = DummyDumper()
+        dumper = DummyDumper()
+        return dumper
 
     @abstractmethod
     def block_manager_init_metrics(self, block_manager):
@@ -57,4 +58,3 @@ def scheduler_step_metrics(self, scheduler):
     @abstractmethod
     def engine_step_metrics(self, scheduler):
         ...
-
\ No newline at end of file
diff --git a/llumnix/metrics/dumper.py b/llumnix/metrics/dumper.py
index 47430721..c7c0216e 100644
--- a/llumnix/metrics/dumper.py
+++ b/llumnix/metrics/dumper.py
@@ -24,10 +24,12 @@ class Dumper(ABC):
     def dump(self, metrics: Dict[str, Any]) -> None:
         ...
 
+
 class LoggerDumper(Dumper):
     def dump(self, metrics: Dict[str, Any]) -> None:
         logger.info("Metrics: {}".format(metrics))
 
+
 class DummyDumper(Dumper):
     def dump(self, metrics: Dict[str, Any]) -> None:
         pass
diff --git a/llumnix/metrics/timestamps.py b/llumnix/metrics/timestamps.py
new file mode 100644
index 00000000..09ca8e1e
--- /dev/null
+++ b/llumnix/metrics/timestamps.py
@@ -0,0 +1,70 @@
+# Copyright (c) 2024, Alibaba Group;
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, field
+from typing import List, Union, Dict, Any, Iterable
+
+def set_timestamp(obj: Any, timestamp_attr: str, timestamp: float):
+    if not isinstance(obj, Iterable):
+        obj = [obj,]
+    objs = list(obj)
+    for item in objs:
+        if hasattr(item, "request_timestamps"):
+            if hasattr(item.request_timestamps, timestamp_attr):
+                setattr(item.request_timestamps, timestamp_attr, timestamp)
+
+
+@dataclass
+class RequestTimestamps:
+    api_server_generate_timestamp: float = field(default=0.0)
+    manager_generate_timestamp: float = field(default=0.0)
+    llumlet_generate_timestamp: float = field(default=0.0)
+    engine_add_request_timestamp: float = field(default=0.0)
+    engine_process_model_outputs_timestamp_begin: float = field(default=0.0)
+    engine_process_model_outputs_timestamp_end: float = field(default=0.0)
+    engine_step_timestamp_begin: float = field(default=0.0)
+    engine_step_timestamp_end: float = field(default=0.0)
+    engine_step_postprocess_timestamp_end: float = field(default=0.0)
+    engine_put_queue_timestamp: float = field(default=0.0)
+    engine_thread_put_queue_timestamp: float = field(default=0.0)
+    engine_actor_put_queue_timestamp: float = field(default=0.0)
+    queue_client_send_timestamp: float = field(default=0.0)
+    queue_server_receive_timestamp: float = field(default=0.0)
+    api_server_get_queue_timestamp: float = field(default=0.0)
+    api_server_generate_timestamp_end: float = field(default=0.0)
+
+    def to_latency_breakdown_dict(self) -> Dict[str, Union[float, List[float]]]:
+        latency_dict = {
+            "across_manager_latency": (self.manager_generate_timestamp - self.api_server_generate_timestamp) * 1000,
+            "across_llumlet_latency": (self.llumlet_generate_timestamp - self.manager_generate_timestamp) * 1000,
+            "across_engine_latency": (self.engine_add_request_timestamp - self.llumlet_generate_timestamp) * 1000,
+            "process_model_outputs_latency":
+                (self.engine_process_model_outputs_timestamp_end - self.engine_process_model_outputs_timestamp_begin) * 1000,
+            "engine_step_latency":
+                (self.engine_step_timestamp_end - self.engine_step_timestamp_begin) * 1000,
+            "step_postprocess_latency":
+                (self.engine_step_postprocess_timestamp_end - self.engine_step_timestamp_end) * 1000,
+            "across_async_put_queue_thread_latency":
+                (self.engine_thread_put_queue_timestamp - self.engine_put_queue_timestamp) * 1000,
+            "across_async_put_queue_actor_latency":
+                (self.engine_actor_put_queue_timestamp - self.engine_thread_put_queue_timestamp) * 1000,
+            "across_queue_client_latency":
+                (self.queue_client_send_timestamp - self.engine_actor_put_queue_timestamp) * 1000,
+            "queue_rpc_latency":
+                (self.queue_server_receive_timestamp - self.queue_client_send_timestamp) * 1000,
+            "api_server_get_queue_latency":
+                (self.api_server_get_queue_timestamp - self.queue_server_receive_timestamp) * 1000,
+            "across_results_generator_latency":
+                (self.api_server_generate_timestamp_end - self.api_server_get_queue_timestamp) * 1000,
+        }
+        return latency_dict
diff --git a/llumnix/metrics/variable.py b/llumnix/metrics/variable.py
index d964c97b..80901ae1 100644
--- a/llumnix/metrics/variable.py
+++ b/llumnix/metrics/variable.py
@@ -39,8 +39,10 @@ def clear(self):
     def remove(self, key) -> None:
         del self._metrics[key]
 
+
 _REGISTRY = Registery()
 
+
 class Variable(ABC):
     def __init__(self, name: str):
         self._name: str = name
@@ -61,6 +63,7 @@ def describe(self):
     def name(self) -> str:
         return self._name
 
+
 class Status(Variable):
     def __init__(self, name: str, initial_value: Any = None):
         super().__init__(name)
diff --git a/llumnix/queue/ray_queue_client.py b/llumnix/queue/ray_queue_client.py
index d4eb9586..ff9182af 100644
--- a/llumnix/queue/ray_queue_client.py
+++ b/llumnix/queue/ray_queue_client.py
@@ -17,20 +17,16 @@
 
 from llumnix.server_info import ServerInfo
 from llumnix.queue.queue_client_base import QueueClientBase
+from llumnix.metrics.timestamps import set_timestamp
 
 
 class RayQueueClient(QueueClientBase):
     async def put_nowait(self, item: Any, server_info: ServerInfo):
         output_queue = server_info.request_output_queue
-        if isinstance(item, Iterable):
-            for request_output in item:
-                if hasattr(request_output, 'request_timestamps'):
-                    request_output.request_timestamps.queue_client_send_timestamp = time.time()
+        set_timestamp(item, 'queue_client_send_timestamp', time.time())
         return await output_queue.actor.put_nowait.remote(item)
 
     async def put_nowait_batch(self, items: Iterable, server_info: ServerInfo):
         output_queue = server_info.request_output_queue
-        for request_output in items:
-            if hasattr(request_output, 'request_timestamps'):
-                request_output.request_timestamps.queue_client_send_timestamp = time.time()
+        set_timestamp(items, 'queue_client_send_timestamp', time.time())
         return await output_queue.actor.put_nowait_batch.remote(items)
diff --git a/llumnix/queue/ray_queue_server.py b/llumnix/queue/ray_queue_server.py
index b8648157..f5d32873 100644
--- a/llumnix/queue/ray_queue_server.py
+++ b/llumnix/queue/ray_queue_server.py
@@ -11,13 +11,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from collections.abc import Iterable
 import time
 
 import ray
 from ray.util.queue import Queue as RayQueue
 from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
 
 from llumnix.queue.queue_server_base import QueueServerBase
+from llumnix.metrics.timestamps import set_timestamp
 
@@ -34,18 +34,13 @@ def __init__(self) -> None:
 
     async def get(self):
         item = await self.queue.actor.get.remote()
-        if isinstance(item, Iterable):
-            for request_output in item:
-                if hasattr(request_output, 'request_timestamps'):
-                    request_output.request_timestamps.queue_server_receive_timestamp = time.time()
+        set_timestamp(item, 'queue_server_receive_timestamp', time.time())
         return item
 
     async def get_nowait_batch(self):
        qsize = await self.queue.actor.qsize.remote()
        items = await self.queue.actor.get_nowait_batch.remote(qsize)
-        for request_output in items:
-            if hasattr(request_output, 'request_timestamps'):
-                request_output.request_timestamps.queue_server_receive_timestamp = time.time()
+        set_timestamp(items, 'queue_server_receive_timestamp', time.time())
        return items
 
     async def run_server_loop(self):
diff --git a/llumnix/queue/zmq_client.py b/llumnix/queue/zmq_client.py
index a6dc4452..cdec50ab 100644
--- a/llumnix/queue/zmq_client.py
+++ b/llumnix/queue/zmq_client.py
@@ -27,6 +27,7 @@
                                     RPCUtilityRequest, RPCPutNoWaitQueueRequest,
                                     RPCPutNoWaitBatchQueueRequest, get_open_zmq_ipc_path)
 from llumnix.constants import RPC_GET_DATA_TIMEOUT_MS, RPC_SOCKET_LIMIT_CUTOFF, RPC_ZMQ_HWM
+from llumnix.metrics.timestamps import set_timestamp
 
 logger = init_logger(__name__)
 
@@ -109,10 +110,7 @@ async def wait_for_server_rpc(self,
     async def put_nowait(self, item: Any, server_info: ServerInfo):
         rpc_path = get_open_zmq_ipc_path(server_info.request_output_queue_ip,
                                          server_info.request_output_queue_port)
-        if isinstance(item, Iterable):
-            for request_output in item:
-                if hasattr(request_output, 'request_timestamps'):
-                    request_output.request_timestamps.queue_client_send_timestamp = time.time()
+        set_timestamp(item, 'queue_client_send_timestamp', time.time())
         await self._send_one_way_rpc_request(
             request=RPCPutNoWaitQueueRequest(item=item),
             rpc_path=rpc_path,
@@ -120,9 +118,7 @@ async def put_nowait(self, item: Any, server_info: ServerInfo):
 
     async def put_nowait_batch(self, items: Iterable, server_info: ServerInfo):
         rpc_path = get_open_zmq_ipc_path(server_info.request_output_queue_ip,
                                          server_info.request_output_queue_port)
-        for request_output in items:
-            if hasattr(request_output, 'request_timestamps'):
-                request_output.request_timestamps.queue_client_send_timestamp = time.time()
+        set_timestamp(items, 'queue_client_send_timestamp', time.time())
         await self._send_one_way_rpc_request(
             request=RPCPutNoWaitBatchQueueRequest(items=items),
             rpc_path=rpc_path,
diff --git a/llumnix/queue/zmq_server.py b/llumnix/queue/zmq_server.py
index f02b1e9c..7180718f 100644
--- a/llumnix/queue/zmq_server.py
+++ b/llumnix/queue/zmq_server.py
@@ -14,7 +14,6 @@
 import asyncio
 import time
 from typing import (Coroutine, Any)
-from collections.abc import Iterable
 from typing_extensions import Never
 
 import zmq
@@ -25,6 +24,7 @@
                                     RPCPutNoWaitBatchQueueRequest, RPCUtilityRequest)
 from llumnix.logging.logger import init_logger
 from llumnix.constants import RPC_SOCKET_LIMIT_CUTOFF, RPC_ZMQ_HWM
+from llumnix.metrics.timestamps import set_timestamp
 
 logger = init_logger(__name__)
 
@@ -128,10 +128,7 @@ async def _is_server_ready(self, identity):
     async def _put_nowait(self, identity, put_nowait_queue_request: RPCPutNoWaitQueueRequest):
         try:
             item = put_nowait_queue_request.item
-            if isinstance(item, Iterable):
-                for request_output in item:
-                    if hasattr(request_output, 'request_timestamps'):
-                        request_output.request_timestamps.queue_server_receive_timestamp = time.time()
+            set_timestamp(item, 'queue_server_receive_timestamp', time.time())
             self.put_nowait(item)
             await self.socket.send_multipart(
                 [identity, cloudpickle.dumps(RPC_SUCCESS_STR)])
@@ -142,9 +139,7 @@ async def _put_nowait(self, identity, put_nowait_queue_request: RPCPutNoWaitQueu
     async def _put_nowait_batch(self, identity, put_nowait_batch_queue_request: RPCPutNoWaitBatchQueueRequest):
         try:
             items = put_nowait_batch_queue_request.items
-            for request_output in items:
-                if hasattr(request_output, 'request_timestamps'):
-                    request_output.request_timestamps.queue_server_receive_timestamp = time.time()
+            set_timestamp(items, 'queue_server_receive_timestamp', time.time())
             self.put_nowait_batch(items)
             await self.socket.send_multipart(
                 [identity, cloudpickle.dumps(RPC_SUCCESS_STR)])
diff --git a/llumnix/server_info.py b/llumnix/server_info.py
index dc1dcba4..74146607 100644
--- a/llumnix/server_info.py
+++ b/llumnix/server_info.py
@@ -15,56 +15,6 @@
 
 from llumnix.queue.queue_type import QueueType
 
-class RequestTimestamps:
-    def __init__(self):
-        self.api_server_manager_generate_timestamp = -1.0
-        self.manager_generate_timestamp = -1.0
-        self.llumlet_generate_timestamp = -1.0
-        self.engine_add_request_timestamp = -1.0
-        self.engine_process_model_outputs_timestamp_begin = -1.0
-        self.engine_process_model_outputs_timestamp_end = -1.0
-        self.engine_step_timestamp_begin = -1.0
-        self.engine_step_timestamp_end = -1.0
-        self.engine_step_postprocess_timestamp_end = -1.0
-        self.engine_thread_put_queue_timestamp = -1.0
-        self.engine_actor_put_queue_timestamp = -1.0
-        self.queue_client_send_timestamp = -1.0
-        self.queue_server_receive_timestamp = -1.0
-        self.api_server_background_process_get_queue_timestamp = -1.0
-        self.api_server_generate_benchmark_timestamp_end = -1.0
-
-    @property
-    def process_model_outputs_latency(self):
-        return (self.engine_process_model_outputs_timestamp_end - self.engine_process_model_outputs_timestamp_begin)*1000
-
-    @property
-    def step_latency_engine(self):
-        return (self.engine_step_timestamp_end - self.engine_step_timestamp_begin)*1000
-
-    @property
-    def step_postprocess_latency(self):
-        return (self.engine_step_postprocess_timestamp_end - self.engine_step_timestamp_end)*1000
-
-    @property
-    def across_async_put_queue_thread_latency(self):
-        return (self.engine_thread_put_queue_timestamp - self.engine_step_timestamp_end)*1000
-
-    @property
-    def across_async_put_queue_actor_latency(self):
-        return (self.engine_actor_put_queue_timestamp - self.engine_thread_put_queue_timestamp)*1000
-
-    @property
-    def queue_rpc_latency(self):
-        return (self.queue_server_receive_timestamp - self.queue_client_send_timestamp)*1000
-
-    @property
-    def background_process_get_queue_latency(self):
-        return (self.api_server_background_process_get_queue_timestamp - self.queue_server_receive_timestamp)*1000
-
-    @property
-    def generate_benchmark_return_output_latency(self):
-        return (self.api_server_generate_benchmark_timestamp_end - self.api_server_background_process_get_queue_timestamp)*1000
-
 class ServerInfo:
     def __init__(self,
                  server_id: str,
diff --git a/tests/e2e_test/utils.py b/tests/e2e_test/utils.py
index 27500a3a..06fc352c 100644
--- a/tests/e2e_test/utils.py
+++ b/tests/e2e_test/utils.py
@@ -28,6 +28,7 @@ def generate_launch_command(result_filename: str = "",
                             model = "facebook/opt-125m",
                             max_model_len: int = 4096,
                             log_instance_info: bool = False,
+                            log_request_timestamps: bool = False,
                             request_migration_policy: str = 'SR',
                             max_num_batched_tokens: int = 16000,
                             enable_pd_disagg: bool = False,
@@ -40,6 +41,7 @@ def generate_launch_command(result_filename: str = "",
         f"--initial-instances {instances_num} "
         f"{'--log-filename manager ' if log_instance_info else ''}"
         f"{'--log-instance-info ' if log_instance_info else ''}"
+        f"{'--log-request-timestamps ' if log_request_timestamps else ''}"
         f"--enable-migration "
         f"--model {model} "
         f"--worker-use-ray "
@@ -66,7 +68,8 @@ def generate_serve_command(result_filename: str = "",
                            migration_backend = "gloo",
                            model = "facebook/opt-125m",
                            max_model_len: int = 4096,
-                           log_instance_info: bool = False,
+                           log_instance_info: bool = True,
+                           log_request_timestamps: bool = True,
                            request_migration_policy: str = 'SR',
                            max_num_batched_tokens: int = 16000,
                            enable_pd_disagg: bool = False,
@@ -78,6 +81,7 @@ def generate_serve_command(result_filename: str = "",
         f"--port {port} "
         f"{'--log-filename manager ' if log_instance_info else ''}"
         f"{'--log-instance-info ' if log_instance_info else ''}"
+        f"{'--log-request-timestamps ' if log_request_timestamps else ''}"
         f"--enable-migration "
        f"--model {model} "
         f"--worker-use-ray "
diff --git a/tests/unit_test/entrypoints/vllm/api_server.py b/tests/unit_test/entrypoints/vllm/api_server.py
index 542a7f81..66dec9a9 100644
--- a/tests/unit_test/entrypoints/vllm/api_server.py
+++ b/tests/unit_test/entrypoints/vllm/api_server.py
@@ -19,7 +19,8 @@
 
 import llumnix.entrypoints.vllm.api_server
 import llumnix.manager
-from llumnix.server_info import ServerInfo, RequestTimestamps
+from llumnix.server_info import ServerInfo
+from llumnix.metrics.timestamps import RequestTimestamps
 from llumnix.utils import random_uuid, get_manager_name
 from llumnix.queue.utils import init_request_output_queue_server, init_request_output_queue_client, QueueType
 from llumnix.entrypoints.utils import EntrypointsContext

From 425c543d22a1a6b502b38eb3f32771a3e8a050d2 Mon Sep 17 00:00:00 2001
From: s5u13b
Date: Thu, 16 Jan 2025 11:18:50 +0000
Subject: [PATCH 14/59] Fix benchmark_serving

---
 benchmark/benchmark_serving.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/benchmark/benchmark_serving.py b/benchmark/benchmark_serving.py
index 056e9cc7..39423eb3 100644
--- a/benchmark/benchmark_serving.py
+++ b/benchmark/benchmark_serving.py
@@ -400,9 +400,8 @@ async def measured(*args, **kwargs):
                 self._all_token_latencies.append(lat_arr)
                 self._decode_sum_latencies.append(decode_sum_latency)
                 self._all_decode_token_latencies.extend(lat_arr[1:,1])
+                self._inference_latencies.append(0.0)
                 if 'per_token_latency_breakdown_list' in output:
-                    step_latency = np.mean([request_timestamps['engine_step_latency'] for request_timestamps in output['per_token_latency_breakdown_list']])
-                    self._inference_latencies.append(step_latency)
                     self._per_token_latency_breakdown_list.append(output['per_token_latency_breakdown_list'])
             return prompt, output
         return measured

From 1361bf6d3692a5fc6b053ffc7fae2b55362ab73c Mon Sep 17 00:00:00 2001
From: s5u13b
Date: Thu, 16 Jan 2025 11:37:34 +0000
Subject: [PATCH 15/59] Minors

---
 llumnix/entrypoints/vllm/client.py | 3 +--
 llumnix/metrics/timestamps.py      | 2 +-
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/llumnix/entrypoints/vllm/client.py b/llumnix/entrypoints/vllm/client.py
index 345f985a..0b10fa44 100644
---
a/llumnix/entrypoints/vllm/client.py +++ b/llumnix/entrypoints/vllm/client.py @@ -21,8 +21,7 @@ class LlumnixClientVLLM: - def __init__(self, - entrypoints_context: EntrypointsContext): + def __init__(self, entrypoints_context: EntrypointsContext): self.manager: Manager = entrypoints_context.manager self.instances: Dict[str, Llumlet] = entrypoints_context.instances self.request_output_queue: QueueServerBase = entrypoints_context.request_output_queue diff --git a/llumnix/metrics/timestamps.py b/llumnix/metrics/timestamps.py index 09ca8e1e..2a5ec000 100644 --- a/llumnix/metrics/timestamps.py +++ b/llumnix/metrics/timestamps.py @@ -64,7 +64,7 @@ def to_latency_breakdown_dict(self) -> Dict[str, Union[float, List[float]]]: (self.queue_server_receive_timestamp - self.queue_client_send_timestamp) * 1000, "api_server_get_queue_latency": (self.api_server_get_queue_timestamp - self.queue_server_receive_timestamp) * 1000, - "across_results_generator_latency": + "across_request_streams_latency": (self.api_server_generate_timestamp_end - self.api_server_get_queue_timestamp) * 1000, } return latency_dict From 6d90be288c1f2eec73f0d637205293a6f52204a1 Mon Sep 17 00:00:00 2001 From: s5u13b Date: Fri, 17 Jan 2025 02:05:12 +0000 Subject: [PATCH 16/59] Minors --- llumnix/metrics/timestamps.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llumnix/metrics/timestamps.py b/llumnix/metrics/timestamps.py index 2a5ec000..1dc146eb 100644 --- a/llumnix/metrics/timestamps.py +++ b/llumnix/metrics/timestamps.py @@ -43,7 +43,7 @@ class RequestTimestamps: api_server_get_queue_timestamp: float = field(default=0.0) api_server_generate_timestamp_end: float = field(default=0.0) - def to_latency_breakdown_dict(self) -> Dict[str, Union[float, List[float]]]: + def to_latency_breakdown_dict(self) -> Dict[str, float]: latency_dict = { "across_manager_latency": (self.manager_generate_timestamp - self.api_server_generate_timestamp) * 1000, "across_llumlet_latency": (self.llumlet_generate_timestamp - self.manager_generate_timestamp) * 1000, From e7bdcd3d0f524632f61c26a458891b43f89ce899 Mon Sep 17 00:00:00 2001 From: s5u13b Date: Fri, 17 Jan 2025 10:01:47 +0000 Subject: [PATCH 17/59] Add poll instance infos and migration tasks log --- llumnix/manager.py | 12 ++++--- llumnix/metrics/timestamps.py | 32 +++++++++---------- .../global_scheduler/test_manager.py | 2 +- 3 files changed, 25 insertions(+), 21 deletions(-) diff --git a/llumnix/manager.py b/llumnix/manager.py index 84db8bf3..3ce46a18 100644 --- a/llumnix/manager.py +++ b/llumnix/manager.py @@ -135,7 +135,7 @@ def __init__(self, # tasks # When manager starts, it automatically connects to all existing instances. 
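The polling loop renamed below relies on one non-obvious asyncio idiom, called out by its inline comment: each Ray remote call is wrapped in `asyncio.gather` so that a done callback can be attached (per the diff's comment, `asyncio.create_task` on the remote call errors). A self-contained toy reproduction of the shape, with a plain coroutine standing in for the Ray actor call:

import asyncio
from functools import partial

async def fake_get_instance_info(instance_id: str) -> str:
    # Stand-in for `instance.get_instance_info.remote()`.
    return f"info-{instance_id}"

async def poll_once(instance_ids):
    def done_callback(instance_id, fut):
        # asyncio.gather resolves to a list of results, one slot per
        # wrapped awaitable -- hence the `[0]` in the real callbacks.
        ret = fut.result()[0]
        print(instance_id, "->", ret)

    tasks = []
    for instance_id in instance_ids:
        # Gather-wrapping yields a future that accepts add_done_callback.
        task = asyncio.gather(fake_get_instance_info(instance_id), return_exceptions=True)
        task.add_done_callback(partial(done_callback, instance_id))
        tasks.append(task)
    await asyncio.gather(*tasks, return_exceptions=True)

asyncio.run(poll_once(["i-0", "i-1"]))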
run_async_func_sync(self._connect_to_instances()) - asyncio.create_task(self._update_instance_info_loop(self.polling_interval)) + asyncio.create_task(self._poll_instance_info_loop(self.polling_interval)) asyncio.create_task(self._clear_request_instance_loop(CLEAR_REQUEST_INSTANCE_INTERVAL)) if hasattr(self, "launch_mode") and self.launch_mode == LaunchMode.GLOBAL: @@ -195,8 +195,8 @@ def abort_done_callback(instance_id: str, request_ids: List[str], fut): tasks.append(task) await asyncio.gather(*tasks, return_exceptions=True) - async def _update_instance_info_loop(self, interval: float) -> None: - def update_instance_info_done_callback(instance_id: str, fut): + async def _poll_instance_info_loop(self, interval: float) -> None: + def get_instance_info_done_callback(instance_id: str, fut): ret = fut.result()[0] if not isinstance(ret, ray.exceptions.RayActorError): if ret is not None: @@ -214,9 +214,11 @@ def update_instance_info_done_callback(instance_id: str, fut): for instance_id, instance in self.instances.items(): # Use asyncio.gather to wrap ray remote call to add done callback, asyncio.create_task will get error. task = asyncio.gather(instance.get_instance_info.remote(), return_exceptions=True) - task.add_done_callback(partial(update_instance_info_done_callback, instance_id)) + task.add_done_callback(partial(get_instance_info_done_callback, instance_id)) tasks.append(task) + logger.debug("Polling instance infos of all instances starts.") await asyncio.gather(*tasks, return_exceptions=True) + logger.debug("Polling instance infos of all instances ends.") self.num_instance_info_updates += 1 # Push migrate when the instance_info have updated a certain number of times. if self.enable_migration and self.num_instance_info_updates != 0 \ @@ -284,7 +286,9 @@ def migrate_done_callback_wrapper(migrate_instance_pair: Tuple[str, str], fut) - return_exceptions=True) task.add_done_callback(partial(migrate_done_callback_wrapper, migrate_instance_pair)) migration_tasks.append(task) + logger.info("{} migration tasks starts.".format(len(migration_tasks))) await asyncio.gather(*migration_tasks, return_exceptions=True) + logger.info("{} migration tasks ends.".format(len(migration_tasks)) # pylint: disable=W0703 except Exception as e: logger.error("Unexpected exception: {}".format(e)) diff --git a/llumnix/metrics/timestamps.py b/llumnix/metrics/timestamps.py index 1dc146eb..b21d2434 100644 --- a/llumnix/metrics/timestamps.py +++ b/llumnix/metrics/timestamps.py @@ -26,22 +26,22 @@ def set_timestamp(obj: Any, timestamp_attr: str, timestamp: float): @dataclass class RequestTimestamps: - api_server_generate_timestamp: float = field(default=0.0) - manager_generate_timestamp: float = field(default=0.0) - llumlet_generate_timestamp: float = field(default=0.0) - engine_add_request_timestamp: float = field(default=0.0) - engine_process_model_outputs_timestamp_begin: float = field(default=0.0) - engine_process_model_outputs_timestamp_end: float = field(default=0.0) - engine_step_timestamp_begin: float = field(default=0.0) - engine_step_timestamp_end: float = field(default=0.0) - engine_step_postprocess_timestamp_end: float = field(default=0.0) - engine_put_queue_timestamp: float = field(default=0.0) - engine_thread_put_queue_timestamp: float = field(default=0.0) - engine_actor_put_queue_timestamp: float = field(default=0.0) - queue_client_send_timestamp: float = field(default=0.0) - queue_server_receive_timestamp: float = field(default=0.0) - api_server_get_queue_timestamp: float = field(default=0.0) - 
api_server_generate_timestamp_end: float = field(default=0.0) + api_server_generate_timestamp: float = 0.0 + manager_generate_timestamp: float = 0.0 + llumlet_generate_timestamp: float = 0.0 + engine_add_request_timestamp: float = 0.0 + engine_process_model_outputs_timestamp_begin: float = 0.0 + engine_process_model_outputs_timestamp_end: float = 0.0 + engine_step_timestamp_begin: float = 0.0 + engine_step_timestamp_end: float = 0.0 + engine_step_postprocess_timestamp_end: float = 0.0 + engine_put_queue_timestamp: float = 0.0 + engine_thread_put_queue_timestamp: float = 0.0 + engine_actor_put_queue_timestamp: float = 0.0 + queue_client_send_timestamp: float = 0.0 + queue_server_receive_timestamp: float = 0.0 + api_server_get_queue_timestamp: float = 0.0 + api_server_generate_timestamp_end: float = 0.0 def to_latency_breakdown_dict(self) -> Dict[str, float]: latency_dict = { diff --git a/tests/unit_test/global_scheduler/test_manager.py b/tests/unit_test/global_scheduler/test_manager.py index 3d8e06f9..104a177f 100644 --- a/tests/unit_test/global_scheduler/test_manager.py +++ b/tests/unit_test/global_scheduler/test_manager.py @@ -321,7 +321,7 @@ def get_instance_info_migrate_out(instance_id): ) return instance_info -def test_update_instance_info_loop_and_migrate(ray_env, manager): +def test_poll_instance_info_loop_and_migrate(ray_env, manager): num_instances = 5 instance_ids, instances = init_instances(num_instances) From ac2d33c6ca82c73cbc72ea86f8a31ba97b03e420 Mon Sep 17 00:00:00 2001 From: s5u13b Date: Mon, 20 Jan 2025 03:12:45 +0000 Subject: [PATCH 18/59] Minors --- llumnix/manager.py | 2 +- llumnix/metrics/timestamps.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/llumnix/manager.py b/llumnix/manager.py index 3ce46a18..b7417b47 100644 --- a/llumnix/manager.py +++ b/llumnix/manager.py @@ -288,7 +288,7 @@ def migrate_done_callback_wrapper(migrate_instance_pair: Tuple[str, str], fut) - migration_tasks.append(task) logger.info("{} migration tasks starts.".format(len(migration_tasks))) await asyncio.gather(*migration_tasks, return_exceptions=True) - logger.info("{} migration tasks ends.".format(len(migration_tasks)) + logger.info("{} migration tasks ends.".format(len(migration_tasks))) # pylint: disable=W0703 except Exception as e: logger.error("Unexpected exception: {}".format(e)) diff --git a/llumnix/metrics/timestamps.py b/llumnix/metrics/timestamps.py index b21d2434..9a6320a1 100644 --- a/llumnix/metrics/timestamps.py +++ b/llumnix/metrics/timestamps.py @@ -11,8 +11,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
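PATCH 17's rewrite of the `RequestTimestamps` fields above is purely cosmetic: for an immutable default such as `0.0`, `field(default=0.0)` and a plain `= 0.0` annotation produce identical dataclass behavior, which is also why the import diff that follows can drop the now-unused `field`. A quick demonstration:

from dataclasses import dataclass, field

@dataclass
class WithField:
    x: float = field(default=0.0)

@dataclass
class Plain:
    x: float = 0.0

assert WithField().x == Plain().x == 0.0
assert WithField(x=1.5).x == Plain(x=1.5).x == 1.5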
-from dataclasses import dataclass, field -from typing import List, Union, Dict, Any, Iterable +from dataclasses import dataclass +from typing import Dict, Any, Iterable def set_timestamp(obj: Any, timestamp_attr: str, timestamp: float): if not isinstance(obj, Iterable): From a86ae924f4918b08c5e3914942e3ef252bfff2e8 Mon Sep 17 00:00:00 2001 From: s5u13b Date: Mon, 20 Jan 2025 03:13:34 +0000 Subject: [PATCH 19/59] Minors --- llumnix/metrics/timestamps.py | 1 + 1 file changed, 1 insertion(+) diff --git a/llumnix/metrics/timestamps.py b/llumnix/metrics/timestamps.py index 9a6320a1..29a78166 100644 --- a/llumnix/metrics/timestamps.py +++ b/llumnix/metrics/timestamps.py @@ -14,6 +14,7 @@ from dataclasses import dataclass from typing import Dict, Any, Iterable + def set_timestamp(obj: Any, timestamp_attr: str, timestamp: float): if not isinstance(obj, Iterable): obj = [obj,] From a45721ba5bb192acc21f3ffc5deb7702846f3a6e Mon Sep 17 00:00:00 2001 From: s5u13b Date: Mon, 20 Jan 2025 03:46:17 +0000 Subject: [PATCH 20/59] Minors --- llumnix/manager.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/llumnix/manager.py b/llumnix/manager.py index b7417b47..afcf738d 100644 --- a/llumnix/manager.py +++ b/llumnix/manager.py @@ -216,9 +216,11 @@ def get_instance_info_done_callback(instance_id: str, fut): task = asyncio.gather(instance.get_instance_info.remote(), return_exceptions=True) task.add_done_callback(partial(get_instance_info_done_callback, instance_id)) tasks.append(task) - logger.debug("Polling instance infos of all instances starts.") + if self.num_instance_info_updates % 100 == 0: + logger.debug("Polling instance infos of all instances starts.") await asyncio.gather(*tasks, return_exceptions=True) - logger.debug("Polling instance infos of all instances ends.") + if self.num_instance_info_updates % 100 == 0: + logger.debug("Polling instance infos of all instances ends.") self.num_instance_info_updates += 1 # Push migrate when the instance_info have updated a certain number of times. 
if self.enable_migration and self.num_instance_info_updates != 0 \ @@ -286,9 +288,11 @@ def migrate_done_callback_wrapper(migrate_instance_pair: Tuple[str, str], fut) - return_exceptions=True) task.add_done_callback(partial(migrate_done_callback_wrapper, migrate_instance_pair)) migration_tasks.append(task) - logger.info("{} migration tasks starts.".format(len(migration_tasks))) + if len(migration_tasks) > 0: + logger.info("{} migration tasks starts.".format(len(migration_tasks))) await asyncio.gather(*migration_tasks, return_exceptions=True) - logger.info("{} migration tasks ends.".format(len(migration_tasks))) + if len(migration_tasks) > 0: + logger.info("{} migration tasks ends.".format(len(migration_tasks))) # pylint: disable=W0703 except Exception as e: logger.error("Unexpected exception: {}".format(e)) From 7936ff09f6fbb3b5addafdf48322b23ebd5ff7fb Mon Sep 17 00:00:00 2001 From: s5u13b Date: Mon, 20 Jan 2025 04:24:54 +0000 Subject: [PATCH 21/59] Fix --- llumnix/manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llumnix/manager.py b/llumnix/manager.py index afcf738d..8ff7c8a1 100644 --- a/llumnix/manager.py +++ b/llumnix/manager.py @@ -263,7 +263,7 @@ async def migrate_done_callback(ret, migrate_instance_pair: Tuple[str, str]) -> logger.info("Instance {} is dead.".format(instance_id)) self.scale_down(instance_id) else: - migrate_out_request_ids = ret[0] + migrate_out_request_ids = ret if migrate_out_request_ids: migrate_out_request_id = migrate_out_request_ids[0] self.request_instance[migrate_out_request_id] = migrate_instance_pair[1] From e59cbc78564fcd00f59ff9375bcba40ebf2b404f Mon Sep 17 00:00:00 2001 From: s5u13b Date: Mon, 20 Jan 2025 08:24:42 +0000 Subject: [PATCH 22/59] Minors --- llumnix/manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llumnix/manager.py b/llumnix/manager.py index 8ff7c8a1..b64de7c3 100644 --- a/llumnix/manager.py +++ b/llumnix/manager.py @@ -217,10 +217,10 @@ def get_instance_info_done_callback(instance_id: str, fut): task.add_done_callback(partial(get_instance_info_done_callback, instance_id)) tasks.append(task) if self.num_instance_info_updates % 100 == 0: - logger.debug("Polling instance infos of all instances starts.") + logger.debug("Polling {} instance infos of all instances starts.".format(len(self.num_instances))) await asyncio.gather(*tasks, return_exceptions=True) if self.num_instance_info_updates % 100 == 0: - logger.debug("Polling instance infos of all instances ends.") + logger.debug("Polling {} instance infos of all instances ends.".format(len(self.num_instances))) self.num_instance_info_updates += 1 # Push migrate when the instance_info have updated a certain number of times. 
if self.enable_migration and self.num_instance_info_updates != 0 \ From 22b7561174abae395a57c4187cc06fdc42150532 Mon Sep 17 00:00:00 2001 From: s5u13b Date: Mon, 20 Jan 2025 08:26:22 +0000 Subject: [PATCH 23/59] Fix --- llumnix/manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llumnix/manager.py b/llumnix/manager.py index b64de7c3..6abff81a 100644 --- a/llumnix/manager.py +++ b/llumnix/manager.py @@ -217,10 +217,10 @@ def get_instance_info_done_callback(instance_id: str, fut): task.add_done_callback(partial(get_instance_info_done_callback, instance_id)) tasks.append(task) if self.num_instance_info_updates % 100 == 0: - logger.debug("Polling {} instance infos of all instances starts.".format(len(self.num_instances))) + logger.debug("Polling {} instance infos of all instances starts.".format(self.num_instances)) await asyncio.gather(*tasks, return_exceptions=True) if self.num_instance_info_updates % 100 == 0: - logger.debug("Polling {} instance infos of all instances ends.".format(len(self.num_instances))) + logger.debug("Polling {} instance infos of all instances ends.".format(self.num_instances)) self.num_instance_info_updates += 1 # Push migrate when the instance_info have updated a certain number of times. if self.enable_migration and self.num_instance_info_updates != 0 \ From 63862f93c2df2f7253b2e43d77eef8cef2bb7d1d Mon Sep 17 00:00:00 2001 From: s5u13b Date: Mon, 20 Jan 2025 09:02:10 +0000 Subject: [PATCH 24/59] Minors --- llumnix/manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llumnix/manager.py b/llumnix/manager.py index 6abff81a..0983fbad 100644 --- a/llumnix/manager.py +++ b/llumnix/manager.py @@ -217,10 +217,10 @@ def get_instance_info_done_callback(instance_id: str, fut): task.add_done_callback(partial(get_instance_info_done_callback, instance_id)) tasks.append(task) if self.num_instance_info_updates % 100 == 0: - logger.debug("Polling {} instance infos of all instances starts.".format(self.num_instances)) + logger.debug("Polling instance infos of {} instances starts.".format(self.num_instances)) await asyncio.gather(*tasks, return_exceptions=True) if self.num_instance_info_updates % 100 == 0: - logger.debug("Polling {} instance infos of all instances ends.".format(self.num_instances)) + logger.debug("Polling instance infos of {} instances ends.".format(self.num_instances)) self.num_instance_info_updates += 1 # Push migrate when the instance_info have updated a certain number of times. 
if self.enable_migration and self.num_instance_info_updates != 0 \ From 30fa6f0ce5ac26e98d061d032e114bb5959d9614 Mon Sep 17 00:00:00 2001 From: s5u13b Date: Mon, 20 Jan 2025 09:52:51 +0000 Subject: [PATCH 25/59] Fix --- llumnix/manager.py | 2 +- tests/unit_test/global_scheduler/test_manager.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/llumnix/manager.py b/llumnix/manager.py index 0983fbad..1edbcf83 100644 --- a/llumnix/manager.py +++ b/llumnix/manager.py @@ -270,7 +270,7 @@ async def migrate_done_callback(ret, migrate_instance_pair: Tuple[str, str]) -> logger.info("instance {}->{} migrate done, migrate request {}".format( migrate_instance_pair[0], migrate_instance_pair[1], migrate_out_request_ids)) def migrate_done_callback_wrapper(migrate_instance_pair: Tuple[str, str], fut) -> None: - ret = fut.result() + ret = fut.result()[0] loop = asyncio.get_event_loop() loop.create_task(migrate_done_callback(ret, migrate_instance_pair)) diff --git a/tests/unit_test/global_scheduler/test_manager.py b/tests/unit_test/global_scheduler/test_manager.py index 104a177f..009947be 100644 --- a/tests/unit_test/global_scheduler/test_manager.py +++ b/tests/unit_test/global_scheduler/test_manager.py @@ -94,7 +94,7 @@ def migrate_out(self, dst_instance_name): migrate_in_ray_actor = ray.get_actor(dst_instance_name, namespace='llumnix') ray.get(migrate_in_ray_actor.migrate_in.remote(self.actor_name)) time.sleep(0.1) - return self.num_migrate_out + return [] def migrate_in(self, src_instance_name): self.num_migrate_in += 1 From 29821e960d0d2b175bad5ab0e1d0580b7b84d334 Mon Sep 17 00:00:00 2001 From: s5u13b Date: Tue, 21 Jan 2025 06:06:31 +0000 Subject: [PATCH 26/59] Minors --- llumnix/llumlet/llumlet.py | 2 +- llumnix/manager.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/llumnix/llumlet/llumlet.py b/llumnix/llumlet/llumlet.py index 5aab64dd..ba087a73 100644 --- a/llumnix/llumlet/llumlet.py +++ b/llumnix/llumlet/llumlet.py @@ -184,7 +184,7 @@ async def _migrate_out_one_request(self, migrate_out_request: LlumnixRequest, ds if status == MigrationStatus.ABORTED_SRC: await migrate_in_ray_actor.execute_migration_method.remote("free_dst_pre_alloc_cache", migrate_out_request.request_id) t1 = time.time() - logger.info("{}->{} migrate done, migrate request {}, migration status: {}, len: {} blocks, cost: {} ms" \ + logger.info("Instance {}->{} migrate done, migrate request {}, migration status: {}, len: {} blocks, cost: {} ms" \ .format(self.instance_id, dst_instance_id, migrated_request, status, \ sum(migrate_out_request.stage_num_blocks_list), (t1 - t0)*1000)) except ray.exceptions.RayActorError: diff --git a/llumnix/manager.py b/llumnix/manager.py index 1edbcf83..4b889bd5 100644 --- a/llumnix/manager.py +++ b/llumnix/manager.py @@ -267,7 +267,7 @@ async def migrate_done_callback(ret, migrate_instance_pair: Tuple[str, str]) -> if migrate_out_request_ids: migrate_out_request_id = migrate_out_request_ids[0] self.request_instance[migrate_out_request_id] = migrate_instance_pair[1] - logger.info("instance {}->{} migrate done, migrate request {}".format( + logger.info("Instance {}->{} migrate done, migrate request {}".format( migrate_instance_pair[0], migrate_instance_pair[1], migrate_out_request_ids)) def migrate_done_callback_wrapper(migrate_instance_pair: Tuple[str, str], fut) -> None: ret = fut.result()[0] From ca7aa5bdcb70c23d3a47d4d72b590ab6a4436ab0 Mon Sep 17 00:00:00 2001 From: s5u13b Date: Tue, 21 Jan 2025 09:28:08 +0000 Subject: [PATCH 27/59] Reorg simulator files --- 
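One contract worth noting from PATCH 25 above: `migrate_out` now resolves to a (possibly empty) list of migrated request ids, so the manager callback unpacks the single gather slot with `fut.result()[0]` and the test double returns `[]` instead of a counter. A toy check of the reassignment logic that consumes it, with hypothetical names:

request_instance = {}

def migrate_done(migrate_out_request_ids, migrate_instance_pair):
    # Mirrors the manager callback: when anything migrated, the first
    # migrated request is reassigned to the destination instance.
    if migrate_out_request_ids:
        migrate_out_request_id = migrate_out_request_ids[0]
        request_instance[migrate_out_request_id] = migrate_instance_pair[1]

migrate_done(["req-7"], ("instance-src", "instance-dst"))
assert request_instance == {"req-7": "instance-dst"}
migrate_done([], ("instance-src", "instance-dst"))  # no-op when nothing migrated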
docs/Simulator.md | 2 +- llumnix/backends/vllm/executor.py | 68 +------------- llumnix/backends/vllm/llm_engine.py | 3 +- llumnix/backends/vllm/sim_executor.py | 93 +++++++++++++++++++ .../vllm/{simulator.py => sim_llm_engine.py} | 1 + .../backends/vllm/test_llm_engine.py | 3 +- .../unit_test/backends/vllm/test_simulator.py | 4 +- .../global_scheduler/test_manager.py | 8 +- 8 files changed, 106 insertions(+), 76 deletions(-) create mode 100644 llumnix/backends/vllm/sim_executor.py rename llumnix/backends/vllm/{simulator.py => sim_llm_engine.py} (99%) diff --git a/docs/Simulator.md b/docs/Simulator.md index 687de62c..1bfec10f 100644 --- a/docs/Simulator.md +++ b/docs/Simulator.md @@ -3,7 +3,7 @@ Llumnix can generate latency data from logs. After run a real benchmark with `-- After running profiling with `python llumnix.backends.profiling.py`. You can get a `$PROFILING_RESULT_FILE_PATH.pkl` -Then, you can run simulator with `--profiling-result-file-path PROFILING_RESULT_FILE_PATH`. +Then, you can run simulator with `--simulator-mode` and `--profiling-result-file-path PROFILING_RESULT_FILE_PATH`. ``` diff --git a/llumnix/backends/vllm/executor.py b/llumnix/backends/vllm/executor.py index 820efc4e..c8418b59 100644 --- a/llumnix/backends/vllm/executor.py +++ b/llumnix/backends/vllm/executor.py @@ -12,10 +12,9 @@ # limitations under the License. import time -import asyncio from collections import defaultdict -from typing import Callable, Dict, List, Optional, Tuple, Type +from typing import List, Optional import ray from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy # pylint: disable=unused-import @@ -263,68 +262,3 @@ async def execute_model_async(self, *args, **kwargs): t1 = time.time() self.last_inference_latency = (t1 - t0) * 1000 return outputs - -class SimGPUExecutor(RayGPUExecutor): - latency_mem: LatencyMemData = None - def __init__(self, *args, **kwargs) -> None: - RayGPUExecutor.__init__(self, *args, **kwargs) - self.last_inference_latency = 0 - self.migration_bandwidth = self.latency_mem.migration_bandwidth - # TODO(ZeldaHuang): add swap bandwidth - - self.cache_block_size = get_cache_block_size( - self.cache_config.block_size, self.model_config, self.parallel_config) - self.cache_block_size /= GiB_bytes - self.sim_cache_config = SimCacheConfig(self.cache_config.gpu_memory_utilization, - self.cache_config.block_size, - self.scheduler_config.max_num_batched_tokens) - - def _init_executor(self) -> None: - pass - - def determine_num_available_blocks(self) -> Tuple[int, int]: - num_gpu_blocks = self.latency_mem.cache_dict.get(self.sim_cache_config, 880) - num_cpu_blocks = 2048 - return (num_gpu_blocks, num_cpu_blocks) - - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: - logger.info("# GPU blocks: {}, # CPU blocks: {}".format(num_gpu_blocks, num_cpu_blocks)) - - async def execute_model_async( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: - prefill_seq_len = 0 - decode_seq_len = 0 - decode_bs = 0 - for meta_data in execute_model_req.seq_group_metadata_list: - if meta_data.is_prompt: - prefill_seq_len += meta_data.token_chunk_size - else: - decode_bs += meta_data.token_chunk_size - decode_seq_len += list(meta_data.seq_data.values())[0].get_len() - decode_bs = _pad_to_alignment(decode_bs, 8) - prefill_seq_len = _pad_to_alignment(prefill_seq_len, 8) - latency = 0 - if prefill_seq_len: - latency += self.latency_mem.prefill_latency[prefill_seq_len][0] if prefill_seq_len in self.latency_mem.prefill_latency \ - 
else model_prefill(prefill_seq_len, *self.latency_mem.prefill_model_params) - if decode_bs: - decode_meta_data = (decode_bs, decode_seq_len) - latency += self.latency_mem.decode_latency[decode_meta_data][0] if decode_meta_data in self.latency_mem.decode_latency \ - else model_decode((decode_bs, decode_seq_len), *self.latency_mem.decode_model_params) - await asyncio.sleep(latency/1000) - sampler_outputs = [] - for meta_data in execute_model_req.seq_group_metadata_list: - samples = [] - for seq_id in meta_data.seq_data.keys(): - dummy_sample_output = SequenceOutput(seq_id, 20, {20: Logprob(1.0)}) - samples.append(dummy_sample_output) - if samples: - output = CompletionSequenceGroupOutput(samples, None) - sampler_outputs.append(output) - return [SamplerOutput(outputs=sampler_outputs)] - - async def send_blocks(self, blocks_len) -> None: - migration_latency = (self.cache_block_size * blocks_len) / self.migration_bandwidth - await asyncio.sleep(migration_latency) diff --git a/llumnix/backends/vllm/llm_engine.py b/llumnix/backends/vllm/llm_engine.py index 49d1375e..a213634a 100644 --- a/llumnix/backends/vllm/llm_engine.py +++ b/llumnix/backends/vllm/llm_engine.py @@ -109,7 +109,7 @@ def from_engine_args( # Initialize the cluster and specify the executor class. # pylint: disable=import-outside-toplevel if latency_mem is not None: - from llumnix.backends.vllm.executor import SimGPUExecutor + from llumnix.backends.vllm.sim_executor import SimGPUExecutor executor_class = SimGPUExecutor executor_class.latency_mem = latency_mem elif engine_config.parallel_config.use_ray: @@ -275,6 +275,7 @@ def _put_request_outputs_to_server(self, request_outputs: List[RequestOutput], s # TODO(s5u13b): Reduce the across-actor overhead. self.async_put_queue_actor.put_nowait_to_servers.remote(server_request_outputs, server_info_dict) + class BackendVLLM(BackendInterface): def __init__( self, diff --git a/llumnix/backends/vllm/sim_executor.py b/llumnix/backends/vllm/sim_executor.py new file mode 100644 index 00000000..dccb224b --- /dev/null +++ b/llumnix/backends/vllm/sim_executor.py @@ -0,0 +1,93 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
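The moved `SimGPUExecutor` (its body follows) estimates step latency instead of running a model: batch shapes are padded to the profiler's alignment, a profiled latency table is consulted first, and a fitted analytical model is the fallback. A self-contained sketch of that estimation path — `_pad_to_alignment` is assumed to round up to the nearest multiple, an affine fit stands in for the real fitted model, and the table values below are made up:

import math

def _pad_to_alignment(x: int, multiple_of: int) -> int:
    # Assumed behavior: round up to the nearest multiple.
    return ((x + multiple_of - 1) // multiple_of) * multiple_of

def estimate_prefill_latency(prefill_seq_len: int, latency_table: dict, model_params) -> float:
    prefill_seq_len = _pad_to_alignment(prefill_seq_len, 8)
    if prefill_seq_len in latency_table:
        return latency_table[prefill_seq_len][0]  # profiled (latency, ...) entry
    a, b = model_params  # fallback: hypothetical affine fit over sequence length
    return a * prefill_seq_len + b

table = {16: (3.2,)}  # hypothetical profile: 16-token prefill measured at 3.2 ms
assert estimate_prefill_latency(10, table, (0.1, 1.0)) == 3.2                 # padded to 16 -> table hit
assert math.isclose(estimate_prefill_latency(100, table, (0.1, 1.0)), 11.4)  # 104*0.1 + 1.0 -> model fallback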
+ +import asyncio + +from typing import List, Tuple + +from vllm.executor.ray_gpu_executor import RayGPUExecutor + +from vllm.sequence import Logprob, SequenceOutput, SequenceGroupOutput, SamplerOutput, ExecuteModelRequest +from vllm.config import _GB + +from llumnix.logging.logger import init_logger +from llumnix.backends.vllm.utils import get_cache_block_size +from llumnix.backends.profiling import LatencyMemData, SimCacheConfig, model_prefill, model_decode, _pad_to_alignment + +logger = init_logger(__name__) + + +class SimGPUExecutor(RayGPUExecutor): + latency_mem: LatencyMemData = None + def __init__(self, *args, **kwargs) -> None: + RayGPUExecutor.__init__(self, *args, **kwargs) + self.last_inference_latency = 0 + self.migration_bandwidth = self.latency_mem.migration_bandwidth + # TODO(ZeldaHuang): add swap bandwidth + + self.cache_block_size = get_cache_block_size( + self.cache_config.block_size, self.model_config, self.parallel_config) + self.cache_block_size /= _GB + self.sim_cache_config = SimCacheConfig(self.cache_config.gpu_memory_utilization, + self.cache_config.block_size, + self.scheduler_config.max_num_batched_tokens) + + def _init_executor(self) -> None: + pass + + def determine_num_available_blocks(self) -> Tuple[int, int]: + num_gpu_blocks = self.latency_mem.cache_dict.get(self.sim_cache_config, 880) + num_cpu_blocks = 2048 + return (num_gpu_blocks, num_cpu_blocks) + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + logger.info("# GPU blocks: {}, # CPU blocks: {}".format(num_gpu_blocks, num_cpu_blocks)) + + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + prefill_seq_len = 0 + decode_seq_len = 0 + decode_bs = 0 + for meta_data in execute_model_req.seq_group_metadata_list: + if meta_data.is_prompt: + prefill_seq_len += meta_data.token_chunk_size + else: + decode_bs += meta_data.token_chunk_size + decode_seq_len += list(meta_data.seq_data.values())[0].get_len() + decode_bs = _pad_to_alignment(decode_bs, 8) + prefill_seq_len = _pad_to_alignment(prefill_seq_len, 8) + latency = 0 + if prefill_seq_len: + latency += self.latency_mem.prefill_latency[prefill_seq_len][0] if prefill_seq_len in self.latency_mem.prefill_latency \ + else model_prefill(prefill_seq_len, *self.latency_mem.prefill_model_params) + if decode_bs: + decode_meta_data = (decode_bs, decode_seq_len) + latency += self.latency_mem.decode_latency[decode_meta_data][0] if decode_meta_data in self.latency_mem.decode_latency \ + else model_decode((decode_bs, decode_seq_len), *self.latency_mem.decode_model_params) + await asyncio.sleep(latency/1000) + sampler_outputs = [] + for meta_data in execute_model_req.seq_group_metadata_list: + samples = [] + for seq_id in meta_data.seq_data.keys(): + dummy_sample_output = SequenceOutput(seq_id, 20, {20: Logprob(1.0)}) + samples.append(dummy_sample_output) + if samples: + output = SequenceGroupOutput(samples, None) + sampler_outputs.append(output) + return [SamplerOutput(outputs=sampler_outputs)] + + async def send_blocks(self, blocks_len) -> None: + migration_latency = (self.cache_block_size * blocks_len) / self.migration_bandwidth + await asyncio.sleep(migration_latency) diff --git a/llumnix/backends/vllm/simulator.py b/llumnix/backends/vllm/sim_llm_engine.py similarity index 99% rename from llumnix/backends/vllm/simulator.py rename to llumnix/backends/vllm/sim_llm_engine.py index 8264dbb7..5d372b17 100644 --- a/llumnix/backends/vllm/simulator.py +++ 
b/llumnix/backends/vllm/sim_llm_engine.py @@ -27,6 +27,7 @@ logger = init_logger(__name__) + class BackendSimVLLM(BackendVLLM): # pylint: disable=super-init-not-called def __init__( diff --git a/tests/unit_test/backends/vllm/test_llm_engine.py b/tests/unit_test/backends/vllm/test_llm_engine.py index ba9b5225..a09f828b 100644 --- a/tests/unit_test/backends/vllm/test_llm_engine.py +++ b/tests/unit_test/backends/vllm/test_llm_engine.py @@ -24,7 +24,8 @@ from vllm.utils import Counter from llumnix.backends.vllm.llm_engine import LLMEngineLlumnix -from llumnix.backends.vllm.executor import LlumnixRayGPUExecutor, SimGPUExecutor +from llumnix.backends.vllm.executor import LlumnixRayGPUExecutor +from llumnix.backends.vllm.sim_executor import SimGPUExecutor from llumnix.backends.profiling import LatencyMemData from llumnix.backends.vllm.sequence import LlumnixRequest from llumnix.queue.queue_type import QueueType diff --git a/tests/unit_test/backends/vllm/test_simulator.py b/tests/unit_test/backends/vllm/test_simulator.py index 19afce1c..5867dee4 100644 --- a/tests/unit_test/backends/vllm/test_simulator.py +++ b/tests/unit_test/backends/vllm/test_simulator.py @@ -7,8 +7,8 @@ from vllm.utils import random_uuid from vllm.sequence import ExecuteModelRequest -from llumnix.backends.vllm.executor import SimGPUExecutor -from llumnix.backends.vllm.simulator import BackendSimVLLM +from llumnix.backends.vllm.sim_executor import SimGPUExecutor +from llumnix.backends.vllm.sim_llm_engine import BackendSimVLLM from llumnix.backends.profiling import LatencyMemData from llumnix.internal_config import MigrationConfig from llumnix.queue.queue_type import QueueType diff --git a/tests/unit_test/global_scheduler/test_manager.py b/tests/unit_test/global_scheduler/test_manager.py index 009947be..c1c3d9a6 100644 --- a/tests/unit_test/global_scheduler/test_manager.py +++ b/tests/unit_test/global_scheduler/test_manager.py @@ -28,8 +28,8 @@ from llumnix.instance_info import InstanceInfo, InstanceLoadCalculator from llumnix.server_info import ServerInfo from llumnix.queue.queue_type import QueueType -from llumnix.instance_info import InstanceType -from llumnix.backends.vllm.simulator import BackendSimVLLM +from llumnix.global_scheduler.scaling_scheduler import InstanceType +from llumnix.backends.vllm.sim_llm_engine import BackendSimVLLM from llumnix.backends.backend_interface import BackendType from llumnix.backends.profiling import LatencyMemData from llumnix.entrypoints.utils import LaunchMode @@ -226,8 +226,8 @@ def test_init_instances(ray_env, manager): def test_init_instances_sim(ray_env, manager): manager.profiling_result_file_path="//" # pylint: disable=import-outside-toplevel - import llumnix.backends.vllm.simulator - llumnix.backends.vllm.simulator.BackendSimVLLM = MockBackendSim + import llumnix.backends.vllm.sim_llm_engine + llumnix.backends.vllm.sim_llm_engine.BackendSimVLLM = MockBackendSim engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True) _, instances = ray.get(manager.init_instances.remote(QueueType("rayqueue"), BackendType.SIM_VLLM, InstanceArgs(), engine_args)) num_instances = len(instances) From 6783a2ed4272da2822134e6da9ae6c276e8eec2a Mon Sep 17 00:00:00 2001 From: s5u13b Date: Wed, 22 Jan 2025 01:45:28 +0000 Subject: [PATCH 28/59] Minors --- llumnix/backends/vllm/sim_executor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/llumnix/backends/vllm/sim_executor.py b/llumnix/backends/vllm/sim_executor.py index dccb224b..98180e79 100644 --- a/llumnix/backends/vllm/sim_executor.py +++ 
b/llumnix/backends/vllm/sim_executor.py @@ -29,6 +29,7 @@ class SimGPUExecutor(RayGPUExecutor): latency_mem: LatencyMemData = None + def __init__(self, *args, **kwargs) -> None: RayGPUExecutor.__init__(self, *args, **kwargs) self.last_inference_latency = 0 From 9bac2109448688c91a37896ac2f32b6361545597 Mon Sep 17 00:00:00 2001 From: s5u13b Date: Wed, 22 Jan 2025 06:17:12 +0000 Subject: [PATCH 29/59] Assert enable_scaling --- llumnix/arg_utils.py | 2 ++ llumnix/config/default.py | 4 ++-- llumnix/global_scheduler/scaling_scheduler.py | 5 +++++ 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/llumnix/arg_utils.py b/llumnix/arg_utils.py index 39b39d79..723529cb 100644 --- a/llumnix/arg_utils.py +++ b/llumnix/arg_utils.py @@ -205,6 +205,8 @@ def check_args(cls, args: 'ManagerArgs', parser: argparse.ArgumentParser): assert not args.enable_port_offset_store or args.enable_port_increment, \ "Set enable_port_increment when enable_port_offset_store" + assert not args.enable_scaling, "Proactive scaling is deprecated now, all scaling related args will not take effects." + @staticmethod def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: parser.add_argument('--initial-instances', diff --git a/llumnix/config/default.py b/llumnix/config/default.py index 7576c257..85ddddbd 100644 --- a/llumnix/config/default.py +++ b/llumnix/config/default.py @@ -96,9 +96,9 @@ # Instance scaling load metric _C.MANAGER.SCALING_LOAD_METRIC = 'remaining_steps' # Minimum number of instances -_C.MANAGER.MIN_INSTANCES = 1 +_C.MANAGER.MIN_INSTANCES = -1 # Maximum number of instances -_C.MANAGER.MAX_INSTANCES = 1 +_C.MANAGER.MAX_INSTANCES = -1 # Interval time to check scaling _C.MANAGER.SCALING_INTERVAL = 10 # Scaling policy diff --git a/llumnix/global_scheduler/scaling_scheduler.py b/llumnix/global_scheduler/scaling_scheduler.py index ff966617..7abe8a45 100644 --- a/llumnix/global_scheduler/scaling_scheduler.py +++ b/llumnix/global_scheduler/scaling_scheduler.py @@ -90,6 +90,7 @@ def get_empty_instance_info(self) -> InstanceInfo: dummy_intance_info.num_available_gpu_blocks_waiting = np.inf return dummy_intance_info + class ScalePolicy(ABC): def __init__(self, scaling_load_metric: str) -> None: self.scaling_load_calculator = ScalingLoadComputation(scaling_load_metric) @@ -121,6 +122,8 @@ def compute_load_metric_up(self, instance_infos: List[InstanceInfo]) -> float: def compute_load_metric_down(self, instance_infos: List[InstanceInfo]) -> float: return max([i.instance_load_dispatch_scale for i in instance_infos]) + + class MinLoad(ScalePolicy): def compute_load_metric_up(self, instance_infos: List[InstanceInfo]) -> float: return min([i.instance_load_dispatch_scale for i in instance_infos]) @@ -128,6 +131,7 @@ def compute_load_metric_up(self, instance_infos: List[InstanceInfo]) -> float: def compute_load_metric_down(self, instance_infos: List[InstanceInfo]) -> float: return min([i.instance_load_dispatch_scale for i in instance_infos]) + class AvgLoad(ScalePolicy): def compute_load_metric_up(self, instance_infos: List[InstanceInfo]) -> float: return self.compute_load_metric_avg(instance_infos) @@ -152,6 +156,7 @@ def compute_load_metric_down(self, instance_infos: List[InstanceInfo]) -> float: tot_instance_info.num_available_gpu_blocks = tot_instance_info.num_free_gpu_blocks - tot_instance_info.num_watermark_blocks return self.scaling_load_calculator.compute_instance_load(tot_instance_info) + class ScalePolicyFactory: _POLICY_REGISTRY = { 'max_load': MaxLoad, From 
4a2fd267b21eb76ca7c43bcf3ef42584ddf9e98b Mon Sep 17 00:00:00 2001 From: s5u13b Date: Wed, 22 Jan 2025 06:19:11 +0000 Subject: [PATCH 30/59] Minors --- llumnix/arg_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llumnix/arg_utils.py b/llumnix/arg_utils.py index 723529cb..dfb80ab8 100644 --- a/llumnix/arg_utils.py +++ b/llumnix/arg_utils.py @@ -205,7 +205,7 @@ def check_args(cls, args: 'ManagerArgs', parser: argparse.ArgumentParser): assert not args.enable_port_offset_store or args.enable_port_increment, \ "Set enable_port_increment when enable_port_offset_store" - assert not args.enable_scaling, "Proactive scaling is deprecated now, all scaling related args will not take effects." + assert not args.enable_scaling, "Proactive auto-scaling is deprecated now, all auto-scaling related args will not take effects." @staticmethod def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: From 1c1097c71b17bcaa2c1c100eb549956e897ce898 Mon Sep 17 00:00:00 2001 From: s5u13b Date: Wed, 22 Jan 2025 06:38:26 +0000 Subject: [PATCH 31/59] Set max_instances for auto scale up --- llumnix/manager.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/llumnix/manager.py b/llumnix/manager.py index 4b889bd5..a214cf16 100644 --- a/llumnix/manager.py +++ b/llumnix/manager.py @@ -318,6 +318,9 @@ async def _auto_scale_up_loop(self, interval: float) -> None: if new_pg is not None and instance_id == new_instance_id: continue self.scale_down(instance_id) + alive_pg_states = list_placement_groups(filters=[("state", "!=", "REMOVED")]) + if self.max_instances != -1 and len(alive_pg_states) >= self.max_instances: + time.sleep(interval) if new_pg is None: new_instance_id = random_uuid() new_pg = self.launcher.init_placement_group(get_placement_group_name(new_instance_id), self.engine_args, self.backend_type, From 7ef370de53cd0168d44b37807c1a16736d404d56 Mon Sep 17 00:00:00 2001 From: s5u13b Date: Wed, 22 Jan 2025 07:15:52 +0000 Subject: [PATCH 32/59] Add retry bind address for zmq server --- llumnix/constants.py | 14 +++++++++----- llumnix/entrypoints/setup.py | 6 +++--- llumnix/entrypoints/utils.py | 22 +++++++++++----------- llumnix/manager.py | 4 ++-- llumnix/queue/utils.py | 3 +-- llumnix/queue/zmq_server.py | 24 ++++++++++++++++++++---- tests/unit_test/queue/test_zmq.py | 6 +++--- 7 files changed, 49 insertions(+), 30 deletions(-) diff --git a/llumnix/constants.py b/llumnix/constants.py index 9b13e74e..b6294ad3 100644 --- a/llumnix/constants.py +++ b/llumnix/constants.py @@ -13,7 +13,7 @@ # llumnix/manager.py CLEAR_REQUEST_INSTANCE_INTERVAL: float = 1000.0 -NO_INSTANCE_RETRY_INTERVAL: float = 1.0 +NO_INSTANCE_RETRY_GENERATE_INTERVAL: float = 1.0 WAIT_ALL_MIGRATIONS_DONE_INTERVAL: float = 0.1 AUTO_SCALE_UP_INTERVAL: float = 1.0 WAIT_PLACEMENT_GROUP_TIMEOUT: float = 5.0 @@ -25,7 +25,7 @@ DISPATCH_LOG_FREQUENCY = 100 # llumnix/entrypoints/setup.py -MAX_RAY_RESTARTS: int = 10 +MAX_RAY_RESTART_TIMES: int = 10 RAY_RESTART_INTERVAL: float = 10.0 # llumnix/entrypoints/vllm/client.py, llumnix/entrypoints/bladellm/client.py @@ -40,11 +40,15 @@ # llumnix/backends/vllm/llm_engine.py NO_OUTPUTS_STEP_INTERVAL: float = 0.01 -# llumnix/queue/zmq_utils.py +# llumnix/queue/zmq_client.py RPC_GET_DATA_TIMEOUT_MS: int = 5000 + +# llumnix/queue/zmq_server.py RPC_SOCKET_LIMIT_CUTOFF: int = 2000 RPC_ZMQ_HWM: int = 0 +RETRY_BIND_ADDRESS_INTERVAL: float = 10.0 +MAX_BIND_ADDRESS_RETRY_TIMES: int = 10 # llumnix/entrypoints/utils.py -MAX_TASK_RETRIES: int = 10 -RETRIES_INTERVAL: float = 5.0 
+MAX_MANAGER_RETRY_TIMES: int = 10 +RETRY_MANAGER_INTERVAL: float = 5.0 diff --git a/llumnix/entrypoints/setup.py b/llumnix/entrypoints/setup.py index ea504275..cc3126c5 100644 --- a/llumnix/entrypoints/setup.py +++ b/llumnix/entrypoints/setup.py @@ -30,7 +30,7 @@ retry_manager_method_sync) from llumnix.backends.backend_interface import BackendType from llumnix.queue.queue_server_base import QueueServerBase -from llumnix.constants import MAX_RAY_RESTARTS, RAY_RESTART_INTERVAL +from llumnix.constants import MAX_RAY_RESTART_TIMES, RAY_RESTART_INTERVAL logger = init_logger(__name__) @@ -59,13 +59,13 @@ def launch_ray_cluster(port: int) -> subprocess.CompletedProcess: sys.exit(1) else: ray_start_command = f"ray start --address={head_node_ip}:{port} --node-ip-address={node_ip_address}" - for attempt in range(MAX_RAY_RESTARTS): + for attempt in range(MAX_RAY_RESTART_TIMES): try: # wait about 2 mins by default result = subprocess.run(['ray', 'start', f'--address={head_node_ip}:{port}'], check=True, text=True, capture_output=True) break except subprocess.CalledProcessError as e: - if attempt < MAX_RAY_RESTARTS: + if attempt < MAX_RAY_RESTART_TIMES: logger.warning("Execute '{}' repeatedly until the head node starts.".format(ray_start_command)) time.sleep(RAY_RESTART_INTERVAL) else: diff --git a/llumnix/entrypoints/utils.py b/llumnix/entrypoints/utils.py index 936feeb6..6f55a274 100644 --- a/llumnix/entrypoints/utils.py +++ b/llumnix/entrypoints/utils.py @@ -7,7 +7,7 @@ import ray from llumnix.logging.logger import init_logger -from llumnix.constants import MAX_TASK_RETRIES, RETRIES_INTERVAL +from llumnix.constants import MAX_MANAGER_RETRY_TIMES, RETRY_MANAGER_INTERVAL logger = init_logger(__name__) @@ -48,29 +48,29 @@ def is_gpu_available() -> bool: return False def retry_manager_method_sync(ray_call, method_name, *args, **kwargs): - for attempt in range(MAX_TASK_RETRIES): + for attempt in range(MAX_MANAGER_RETRY_TIMES): try: ret = ray.get(ray_call(*args, **kwargs)) break except ray.exceptions.RayActorError: - if attempt < MAX_TASK_RETRIES - 1: - logger.warning("Manager is unavailable, sleep {}s, and retry {} again.".format(RETRIES_INTERVAL, method_name)) - time.sleep(RETRIES_INTERVAL) + if attempt < MAX_MANAGER_RETRY_TIMES - 1: + logger.warning("Manager is unavailable, sleep {}s, and retry {} again.".format(RETRY_MANAGER_INTERVAL, method_name)) + time.sleep(RETRY_MANAGER_INTERVAL) else: - logger.error("Manager is still unavailable after {} times retries.".format(MAX_TASK_RETRIES)) + logger.error("Manager is still unavailable after {} times retries.".format(MAX_MANAGER_RETRY_TIMES)) raise return ret async def retry_manager_method_async(ray_call, method_name, *args, **kwargs): - for attempt in range(MAX_TASK_RETRIES): + for attempt in range(MAX_MANAGER_RETRY_TIMES): try: ret = await ray_call(*args, **kwargs) break except ray.exceptions.RayActorError: - if attempt < MAX_TASK_RETRIES - 1: - logger.warning("Manager is unavailable, sleep {}s, and retry {} again.".format(RETRIES_INTERVAL, method_name)) - await asyncio.sleep(RETRIES_INTERVAL) + if attempt < MAX_MANAGER_RETRY_TIMES - 1: + logger.warning("Manager is unavailable, sleep {}s, and retry {} again.".format(RETRY_MANAGER_INTERVAL, method_name)) + await asyncio.sleep(RETRY_MANAGER_INTERVAL) else: - logger.error("Manager is still unavailable after {} times retries.".format(MAX_TASK_RETRIES)) + logger.error("Manager is still unavailable after {} times retries.".format(MAX_MANAGER_RETRY_TIMES)) raise return ret diff --git a/llumnix/manager.py 
b/llumnix/manager.py index a214cf16..193499e9 100644 --- a/llumnix/manager.py +++ b/llumnix/manager.py @@ -149,8 +149,8 @@ def __init__(self, async def generate(self, request_id: str, server_info: ServerInfo, *args, **kwargs,) -> None: while self.num_instances == 0: logger.warning("No instance available now, sleep {}s, " - "and regenerate request {}.".format(NO_INSTANCE_RETRY_INTERVAL, request_id)) - await asyncio.sleep(NO_INSTANCE_RETRY_INTERVAL) + "and regenerate request {}.".format(NO_INSTANCE_RETRY_GENERATE_INTERVAL, request_id)) + await asyncio.sleep(NO_INSTANCE_RETRY_GENERATE_INTERVAL) instance_id, request_expected_steps = self.global_scheduler.dispatch() try: diff --git a/llumnix/queue/utils.py b/llumnix/queue/utils.py index 77c6e07e..c9ac333a 100644 --- a/llumnix/queue/utils.py +++ b/llumnix/queue/utils.py @@ -27,8 +27,7 @@ def init_request_output_queue_server(zmq_ip: str, zmq_port: int, queue_type: QueueType) -> QueueServerBase: output_queue_server: QueueServerBase = None if queue_type == QueueType.ZMQ: - rpc_path = get_open_zmq_ipc_path(zmq_ip, zmq_port) - output_queue_server = ZmqServer(rpc_path) + output_queue_server = ZmqServer(zmq_ip, zmq_port) else: output_queue_server = RayQueueServer() return output_queue_server diff --git a/llumnix/queue/zmq_server.py b/llumnix/queue/zmq_server.py index 7180718f..aaa0b43c 100644 --- a/llumnix/queue/zmq_server.py +++ b/llumnix/queue/zmq_server.py @@ -19,11 +19,13 @@ import zmq import zmq.asyncio import cloudpickle +import zmq.error from llumnix.queue.zmq_utils import (RPC_SUCCESS_STR, RPCPutNoWaitQueueRequest, RPCPutNoWaitBatchQueueRequest, RPCUtilityRequest) from llumnix.logging.logger import init_logger -from llumnix.constants import RPC_SOCKET_LIMIT_CUTOFF, RPC_ZMQ_HWM +from llumnix.constants import (RPC_SOCKET_LIMIT_CUTOFF, RPC_ZMQ_HWM, RETRY_BIND_ADDRESS_INTERVAL, + MAX_BIND_ADDRESS_RETRY_TIMES) from llumnix.metrics.timestamps import set_timestamp logger = init_logger(__name__) @@ -37,7 +39,9 @@ class Full(Exception): class ZmqServer: - def __init__(self, rpc_path: str, maxsize=0): + def __init__(self, ip: str, port: int, maxsize=0): + rpc_path = get_open_zmq_ipc_path(ip, port) + self.context = zmq.asyncio.Context() # Maximum number of sockets that can be opened (typically 65536). 
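The next hunk retries binding the ROUTER socket to `rpc_path`, which `ZmqServer` now builds internally from `(ip, port)` via `get_open_zmq_ipc_path`; a bind to an endpoint already in use raises `zmq.error.ZMQError`, which is exactly what the new loop catches. The helper's definition is not part of this series; given that callers pass an IP and a TCP port, a plausible sketch (an assumption, not the verbatim helper) is:

def get_open_zmq_ipc_path(ip: str, port: int) -> str:
    # Assumed: despite the 'ipc' in the name, an endpoint usable by
    # socket.bind()/connect() is built from a host and TCP port.
    return f"tcp://{ip}:{port}"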
@@ -55,8 +59,20 @@ def __init__(self, rpc_path: str, maxsize=0): self.socket = self.context.socket(zmq.constants.ROUTER) self.socket.set_hwm(RPC_ZMQ_HWM) - self.socket.bind(rpc_path) - logger.info("QueueServer's socket bind to: {}".format(rpc_path)) + + for attempt in range(MAX_BIND_ADDRESS_RETRY_TIMES): + try: + self.socket.bind(rpc_path) + logger.info("QueueServer's socket bind to: {}".format(rpc_path)) + break + except zmq.error.ZMQError as e: + logger.warning("QueueServer's socket bind to {} failed, exception: {}".format(rpc_path, e)) + if attempt < MAX_BIND_ADDRESS_RETRY_TIMES - 1: + logger.warning("{} already in use, sleep {}s, and retry bind to it again.".format(rpc_path, RETRY_BIND_ADDRESS_INTERVAL)) + time.sleep(RETRY_BIND_ADDRESS_INTERVAL) + else: + logger.error("{} still in use after {} times retries.".format(rpc_path, MAX_BIND_ADDRESS_RETRY_TIMES)) + raise self.maxsize = maxsize self.queue = asyncio.Queue(maxsize) diff --git a/tests/unit_test/queue/test_zmq.py b/tests/unit_test/queue/test_zmq.py index 3ce64212..de14101a 100644 --- a/tests/unit_test/queue/test_zmq.py +++ b/tests/unit_test/queue/test_zmq.py @@ -29,7 +29,8 @@ @ray.remote(num_cpus=1) class Server: - def __init__(self, rpc_path): + def __init__(self, ip, port): + rpc_path = get_open_zmq_ipc_path(ip, port) self.server = ZmqServer(rpc_path) asyncio.create_task(self.server.run_server_loop()) request_output_queue = self.server @@ -77,10 +78,9 @@ def timeout_handler(signum, frame): raise TimeoutException("Function call timed out") async def benchmark_queue(qps, ip=None, port=None): - rpc_path = get_open_zmq_ipc_path(ip, port) rpc_client = ZmqClient() request_output_queue = rpc_client - server = Server.remote(rpc_path) + server = Server.remote(ip, port) server_id = random_uuid() server_info = ServerInfo(server_id, 'zmq', None, ip, port) await rpc_client.wait_for_server_rpc(server_info) From 7856ca00668302c0b24ae620331ea79c2c887340 Mon Sep 17 00:00:00 2001 From: s5u13b Date: Wed, 22 Jan 2025 07:19:37 +0000 Subject: [PATCH 33/59] Fix lint --- llumnix/queue/utils.py | 1 - llumnix/queue/zmq_server.py | 2 +- tests/unit_test/queue/test_zmq.py | 4 +--- 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/llumnix/queue/utils.py b/llumnix/queue/utils.py index c9ac333a..69931276 100644 --- a/llumnix/queue/utils.py +++ b/llumnix/queue/utils.py @@ -17,7 +17,6 @@ from llumnix.queue.ray_queue_server import RayQueueServer from llumnix.queue.zmq_client import ZmqClient from llumnix.queue.ray_queue_client import RayQueueClient -from llumnix.queue.zmq_utils import get_open_zmq_ipc_path from llumnix.queue.queue_type import QueueType from llumnix.logging.logger import init_logger diff --git a/llumnix/queue/zmq_server.py b/llumnix/queue/zmq_server.py index aaa0b43c..345238c8 100644 --- a/llumnix/queue/zmq_server.py +++ b/llumnix/queue/zmq_server.py @@ -18,8 +18,8 @@ import zmq import zmq.asyncio -import cloudpickle import zmq.error +import cloudpickle from llumnix.queue.zmq_utils import (RPC_SUCCESS_STR, RPCPutNoWaitQueueRequest, RPCPutNoWaitBatchQueueRequest, RPCUtilityRequest) diff --git a/tests/unit_test/queue/test_zmq.py b/tests/unit_test/queue/test_zmq.py index de14101a..37e5cbbc 100644 --- a/tests/unit_test/queue/test_zmq.py +++ b/tests/unit_test/queue/test_zmq.py @@ -20,7 +20,6 @@ from llumnix.queue.zmq_server import ZmqServer from llumnix.queue.zmq_client import ZmqClient -from llumnix.queue.utils import get_open_zmq_ipc_path from llumnix.utils import random_uuid from llumnix.server_info import ServerInfo @@ -30,8 +29,7 @@ 
 @ray.remote(num_cpus=1)
 class Server:
     def __init__(self, ip, port):
-        rpc_path = get_open_zmq_ipc_path(ip, port)
-        self.server = ZmqServer(rpc_path)
+        self.server = ZmqServer(ip, port)
         asyncio.create_task(self.server.run_server_loop())
         request_output_queue = self.server
         self.stop_signal = asyncio.Event()

From f6bd44c2a8223f7ad5ebf205ad654726b20834fb Mon Sep 17 00:00:00 2001
From: s5u13b
Date: Wed, 22 Jan 2025 07:20:49 +0000
Subject: [PATCH 34/59] Fix unit test

---
 llumnix/queue/zmq_server.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/llumnix/queue/zmq_server.py b/llumnix/queue/zmq_server.py
index 345238c8..7756452a 100644
--- a/llumnix/queue/zmq_server.py
+++ b/llumnix/queue/zmq_server.py
@@ -22,7 +22,8 @@
 import cloudpickle

 from llumnix.queue.zmq_utils import (RPC_SUCCESS_STR, RPCPutNoWaitQueueRequest,
-                                     RPCPutNoWaitBatchQueueRequest, RPCUtilityRequest)
+                                     RPCPutNoWaitBatchQueueRequest, RPCUtilityRequest,
+                                     get_open_zmq_ipc_path)
 from llumnix.logging.logger import init_logger
 from llumnix.constants import (RPC_SOCKET_LIMIT_CUTOFF, RPC_ZMQ_HWM,
                                RETRY_BIND_ADDRESS_INTERVAL, MAX_BIND_ADDRESS_RETRY_TIMES)

From bfc7b54083565eb9f65a002f8615f476d931b410 Mon Sep 17 00:00:00 2001
From: s5u13b
Date: Thu, 23 Jan 2025 02:29:23 +0000
Subject: [PATCH 35/59] Refine dispatch scheduler implementation

---
 llumnix/global_scheduler/dispatch_scheduler.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llumnix/global_scheduler/dispatch_scheduler.py b/llumnix/global_scheduler/dispatch_scheduler.py
index a4aab518..110c8444 100644
--- a/llumnix/global_scheduler/dispatch_scheduler.py
+++ b/llumnix/global_scheduler/dispatch_scheduler.py
@@ -63,7 +63,7 @@ class DispatchPolicy(ABC):
     @abstractmethod
     def dispatch(self,
                  instance_num_requests: Dict[str, int],
-                 available_instance_infos: List[InstanceInfo]) -> str:
+                 available_instance_infos: List[InstanceInfo]) -> int:
         pass

 # Dispatch all requests to a single instance, used only for testing

From e3ace2d96b338ed7359a54eb499d6a991735c0ab Mon Sep 17 00:00:00 2001
From: s5u13b
Date: Thu, 23 Jan 2025 07:40:02 +0000
Subject: [PATCH 36/59] Support power-of-k-choice for dispatch

---
 docs/Arguments.md                             |   5 +
 llumnix/arg_utils.py                          |   9 ++
 llumnix/config/default.py                     |   2 +
 llumnix/global_scheduler/dispatch_policy.py   | 106 ++++++++++++++++++
 .../global_scheduler/dispatch_scheduler.py    |  87 +-------------
 llumnix/global_scheduler/global_scheduler.py  |   5 +-
 llumnix/global_scheduler/migration_filter.py  |   6 +
 llumnix/global_scheduler/migration_policy.py  |   4 +
 llumnix/global_scheduler/scaling_policy.py    |  86 ++++++++++++++
 llumnix/global_scheduler/scaling_scheduler.py |  80 +------------
 llumnix/internal_config.py                    |   5 +-
 .../test_dispatch_scheduler.py                |  37 +++++-
 .../global_scheduler/test_global_scheduler.py |   2 +-
 13 files changed, 266 insertions(+), 168 deletions(-)
 create mode 100644 llumnix/global_scheduler/dispatch_policy.py
 create mode 100644 llumnix/global_scheduler/scaling_policy.py

diff --git a/docs/Arguments.md b/docs/Arguments.md
index 6c5c6e89..d1c7b8aa 100644
--- a/docs/Arguments.md
+++ b/docs/Arguments.md
@@ -25,6 +25,7 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h]
                 [--scaling-load-metric {remaining_steps,usage_ratio}]
                 [--polling-interval POLLING_INTERVAL]
                 [--dispatch-policy {balanced,load,queue,rr}]
+                [--power-of-k-choice POWER_OF_K_CHOICE]
                 [--enable-migration]
                 [--enable-defrag]
                 [--pair-migration-frequency PAIR_MIGRATION_FREQUENCY]
@@ -139,6 +140,10 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h]
 - Possible choices: balanced, load, queue, rr
 - Default: "load"

+`--power-of-k-choice`
+- Number of candidate instances for the dispatch policy
+- Default: 1
+
 `--enable-migration`
 - Enable migrate requests between instances.

diff --git a/llumnix/arg_utils.py b/llumnix/arg_utils.py
index dfb80ab8..ef47e847 100644
--- a/llumnix/arg_utils.py
+++ b/llumnix/arg_utils.py
@@ -119,6 +119,7 @@ class ManagerArgs:
     polling_interval: float = None
     dispatch_policy: str = None
     scaling_load_metric: str = None
+    power_of_k_choice: int = None

     enable_migration: bool = None
     pair_migration_frequency: int = None
@@ -174,6 +175,7 @@ def create_global_scheduler_config(self, is_group_kind_migration_backend) -> Tup
         # Create the GlobalScheduler Configuration.
         global_scheduler_config = GlobalSchedulerConfig(self.initial_instances,
                                                         self.dispatch_policy,
                                                         self.power_of_k_choice,
                                                         self.pair_migration_policy,
                                                         self.migrate_out_threshold,
                                                         self.scaling_policy,
@@ -228,6 +230,13 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
                             '* "queue" dispatch request to the instance with minimum waiting request queue length.\n'
                             '* "flood" dispatch request to the instance with maximum requests dispatched.\n'
                             '* "rr" dispatch requests with round-robin policy.\n')
+        parser.add_argument('--power-of-k-choice',
+                            type=int,
+                            help='number of candidate instances for the dispatch policy.\n\n'
+                                 'The top-k candidate instances are first selected according to the dispatch '
+                                 'policy\'s load metric (e.g., instance load or waiting queue size), and then '
+                                 'one of them is chosen at random to receive the request, for better load balancing.')
+
         parser.add_argument('--enable-migration',
                             action='store_true',
                             help='enable migrate requests between instances')

diff --git a/llumnix/config/default.py b/llumnix/config/default.py
index 85ddddbd..752b9d99 100644
--- a/llumnix/config/default.py
+++ b/llumnix/config/default.py
@@ -79,6 +79,8 @@
 # -------------------------- DISPATCH CONFIGURATION ---------------------------
 # Request dispatch policy
 _C.MANAGER.DISPATCH_POLICY = 'load'
+# Number of candidate instances for the dispatch policy
+_C.MANAGER.POWER_OF_K_CHOICE = 1

 # -------------------------- MIGRATION CONFIGURATION --------------------------
 # Enable migrate requests between instances

diff --git a/llumnix/global_scheduler/dispatch_policy.py b/llumnix/global_scheduler/dispatch_policy.py
new file mode 100644
index 00000000..c16a8692
--- /dev/null
+++ b/llumnix/global_scheduler/dispatch_policy.py
@@ -0,0 +1,106 @@
+from typing import Dict, List
+from abc import ABC, abstractmethod
+import random
+
+from llumnix.logging.logger import init_logger
+from llumnix.instance_info import InstanceInfo
+
+
+logger = init_logger(__name__)
+
+
+def sort_instance_infos(available_instance_infos: List[InstanceInfo],
+                        key_attr: str,
+                        descending: bool = False) -> List[InstanceInfo]:
+    return sorted(
+        available_instance_infos,
+        key=lambda instance_info: getattr(instance_info, key_attr),
+        reverse=descending
+    )
+
+def random_choice_from_top_k(sorted_instance_infos: List[InstanceInfo],
+                             power_of_k_choice: int):
+    k = min(power_of_k_choice, len(sorted_instance_infos))
+    top_k_instance_infos = sorted_instance_infos[:k]
+    return random.choice(top_k_instance_infos)
+
+
+class DispatchPolicy(ABC):
+    @abstractmethod
+    def dispatch(self,
+                 instance_num_requests: Dict[str, int],
+                 available_instance_infos: List[InstanceInfo],
+                 power_of_k_choice: int) -> str:
+        pass
+
+
+# Dispatch all requests to a single instance, used only for testing
+class Flood(DispatchPolicy):
+    def dispatch(self,
+                 instance_num_requests: Dict[str, int],
+                 available_instance_infos: List[InstanceInfo],
+                 power_of_k_choice: int) -> str:
+        instance_id = max(instance_num_requests, key=instance_num_requests.get)
+        return instance_id
+
+
+class Balanced(DispatchPolicy):
+    def dispatch(self,
+                 instance_num_requests: Dict[str, int],
+                 available_instance_infos: List[InstanceInfo],
+                 power_of_k_choice: int) -> str:
+        # dispatch request according to the number of requests dispatched to instance by manager
+        instance_id = min(instance_num_requests, key=instance_num_requests.get)
+        return instance_id
+
+
+class Load(DispatchPolicy):
+    def dispatch(self,
+                 instance_num_requests: Dict[str, int],
+                 available_instance_infos: List[InstanceInfo],
+                 power_of_k_choice: int) -> str:
+        sorted_instance_infos = sort_instance_infos(available_instance_infos, 'dispatch_load_metric')
+        instance_info_chosen = random_choice_from_top_k(sorted_instance_infos, power_of_k_choice)
+        instance_id = instance_info_chosen.instance_id
+        logger.info("dispatch to {}, load: {}".format(instance_id, instance_info_chosen.instance_load_dispatch_scale))
+        return instance_id
+
+
+class Queue(DispatchPolicy):
+    def dispatch(self,
+                 instance_num_requests: Dict[str, int],
+                 available_instance_infos: List[InstanceInfo],
+                 power_of_k_choice: int) -> str:
+        sorted_instance_infos = sort_instance_infos(available_instance_infos, 'num_waiting_requests')
+        instance_info_chosen = random_choice_from_top_k(sorted_instance_infos, power_of_k_choice)
+        instance_id = instance_info_chosen.instance_id
+        logger.info("dispatch to {}, queue size: {}".format(instance_id, instance_info_chosen.num_waiting_requests))
+        return instance_id
+
+
+class RoundRobin(DispatchPolicy):
+    prev_instance_idx: int = -1
+
+    def dispatch(self,
+                 instance_num_requests: Dict[str, int],
+                 available_instance_infos: List[InstanceInfo],
+                 power_of_k_choice: int) -> str:
+        all_instance_ids = sorted(instance_num_requests.keys())
+        cur_instance_idx = (self.prev_instance_idx + 1) % len(all_instance_ids)
+        target_instance_id = all_instance_ids[cur_instance_idx]
+        self.prev_instance_idx = cur_instance_idx
+        return target_instance_id
+
+
+class DispatchPolicyFactory:
+    _POLICY_REGISTRY = {
+        'flood': Flood,
+        'balanced': Balanced,
+        'load': Load,
+        'queue': Queue,
+        'rr': RoundRobin,
+    }
+
+    @classmethod
+    def get_policy(cls, policy_name: str, **kwargs) -> DispatchPolicy:
+        return cls._POLICY_REGISTRY[policy_name](**kwargs)
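The new file above is the heart of the power-of-k-choice feature: sort the candidates by the policy's metric, then pick uniformly at random among the k best. A minimal self-contained sketch of that selection step (`FakeInstanceInfo` is a hypothetical stand-in carrying only the field the `Load` policy reads, not Llumnix's real `InstanceInfo`):

```python
import random

# Hypothetical stand-in for the single field the Load policy reads.
class FakeInstanceInfo:
    def __init__(self, instance_id: str, dispatch_load_metric: float):
        self.instance_id = instance_id
        self.dispatch_load_metric = dispatch_load_metric

infos = [FakeInstanceInfo("instance_0", 0.2),
         FakeInstanceInfo("instance_1", 0.9),
         FakeInstanceInfo("instance_2", 0.5)]

# Sort ascending by load, then choose uniformly among the k least-loaded instances.
power_of_k_choice = 2
sorted_infos = sorted(infos, key=lambda info: info.dispatch_load_metric)
k = min(power_of_k_choice, len(sorted_infos))
chosen = random.choice(sorted_infos[:k])
print(chosen.instance_id)  # "instance_0" or "instance_2", never the most-loaded "instance_1"
```

Randomizing over the top k trades a little per-request optimality for fewer herd effects when many requests arrive between load-metric updates.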
diff --git a/llumnix/global_scheduler/dispatch_scheduler.py b/llumnix/global_scheduler/dispatch_scheduler.py
index 110c8444..e4c8245d 100644
--- a/llumnix/global_scheduler/dispatch_scheduler.py
+++ b/llumnix/global_scheduler/dispatch_scheduler.py
@@ -12,8 +12,6 @@
 # limitations under the License.

 from typing import Dict, List, Set
-from abc import ABC, abstractmethod
-import random

 from llumnix.logging.logger import init_logger
 from llumnix.instance_info import InstanceInfo, InstanceType
@@ -24,8 +22,11 @@

 class DispatchScheduler:
-    def __init__(self, dispatch_policy: str,) -> None:
+    def __init__(self,
+                 dispatch_policy: str,
+                 power_of_k_choice: int) -> None:
         self.dispatch_policy = DispatchPolicyFactory.get_policy(dispatch_policy)
+        self.power_of_k_choice = power_of_k_choice
         self.available_dispatch_instance_set: Set[str] = set()
         self.instance_info: Dict[str, InstanceInfo] = {}
         # statistics
@@ -35,7 +36,8 @@ def __init__(self, dispatch_policy: str,) -> None:
     def dispatch(self) -> str:
         self.total_num_requests += 1
         dispatch_instance_id = self.dispatch_policy.dispatch(self.instance_num_requests,
-                                                             self.instance_info.values())
+                                                             self.instance_info.values(),
+                                                             self.power_of_k_choice)
         self.instance_num_requests[dispatch_instance_id] += 1
         if self.total_num_requests % DISPATCH_LOG_FREQUENCY == 0:
             logger.info("dispatch scheduler total_dispatched_requests: {}".format(self.total_num_requests))
@@ -58,80 +60,3 @@ def remove_instance(self, instance_id: str) -> None:
         if instance_id in self.available_dispatch_instance_set:
             self.available_dispatch_instance_set.remove(instance_id)
         self.instance_num_requests.pop(instance_id, None)
-
-class DispatchPolicy(ABC):
-    @abstractmethod
-    def dispatch(self,
-                 instance_num_requests: Dict[str, int],
-                 available_instance_infos: List[InstanceInfo]) -> int:
-        pass
-
-# Dispatch all requests to a single instance, used only for testing
-class Flood(DispatchPolicy):
-    def dispatch(self,
-                 instance_num_requests: Dict[str, int],
-                 available_instance_infos: List[InstanceInfo]) -> str:
-        instance_id = max(instance_num_requests, key=instance_num_requests.get)
-        return instance_id
-
-class Balanced(DispatchPolicy):
-    def dispatch(self,
-                 instance_num_requests: Dict[str, int],
-                 available_instance_infos: List[InstanceInfo]) -> str:
-        # dispatch request according to the number of requests dispatched to instance by manager
-        instance_id = min(instance_num_requests, key=instance_num_requests.get)
-        return instance_id
-
-class Load(DispatchPolicy):
-    def dispatch(self,
-                 instance_num_requests: Dict[str, int],
-                 available_instance_infos: List[InstanceInfo]) -> str:
-        sorted_instance_infos = sorted(
-            available_instance_infos,
-            key=lambda instance_info: getattr(instance_info, 'dispatch_load_metric'),
-        )
-        instance_id = sorted_instance_infos[0].instance_id
-        logger.debug("dispatch to {}, load: {}".format(instance_id, sorted_instance_infos[0].dispatch_load_metric))
-        return instance_id
-
-class Queue(DispatchPolicy):
-    def dispatch(self,
-                 instance_num_requests: Dict[str, int],
-                 available_instance_infos: List[InstanceInfo]) -> str:
-        sorted_instance_infos = sorted(
-            available_instance_infos,
-            key=lambda instance_info: getattr(instance_info, 'num_waiting_requests'),
-        )
-        min_queue_size = sorted_instance_infos[0].num_waiting_requests
-        instance_id_list = []
-        for instance_info in sorted_instance_infos:
-            if instance_info.num_waiting_requests == min_queue_size:
-                instance_id_list.append(instance_info.instance_id)
-        instance_id = random.choice(instance_id_list)
-        logger.debug("dispatch to {}, queue size: {}".format(instance_id, sorted_instance_infos[0].num_waiting_requests))
-        return instance_id
-
-class RoundRobin(DispatchPolicy):
-    next_instance_idx: int = 0
-
-    def dispatch(self,
-                 instance_num_requests: Dict[str, int],
-                 available_instance_infos: List[InstanceInfo]) -> str:
-        all_instance_ids = sorted(instance_num_requests.keys())
-        assert len(all_instance_ids) > 0
-        target_instance_id = all_instance_ids[self.next_instance_idx % len(all_instance_ids)]
-        self.next_instance_idx += 1
-        return target_instance_id
-
-class DispatchPolicyFactory:
-    _POLICY_REGISTRY = {
-        'flood': Flood,
-        'balanced': Balanced,
-        'load': Load,
-        'queue': Queue,
-        'rr': RoundRobin,
-    }
-
-    @classmethod
-    def get_policy(cls, policy_name: str, **kwargs) -> DispatchPolicy:
-        return cls._POLICY_REGISTRY[policy_name](**kwargs)
diff --git a/llumnix/global_scheduler/global_scheduler.py b/llumnix/global_scheduler/global_scheduler.py
index 548e7fe4..9f049a47 100644
--- a/llumnix/global_scheduler/global_scheduler.py
+++ b/llumnix/global_scheduler/global_scheduler.py
@@ -34,7 +34,8 @@ def __init__(self, global_scheduler_config: GlobalSchedulerConfig) -> None:
         self.instance_info: Dict[str, InstanceInfo] = {}

         # dispatch args
-        self.dispatch_scheduler = DispatchScheduler(global_scheduler_config.dispatch_policy)
+        self.dispatch_scheduler = DispatchScheduler(global_scheduler_config.dispatch_policy,
+                                                    global_scheduler_config.power_of_k_choice)
         # migrate args
         self.migration_scheduler = MigrationScheduler(global_scheduler_config.pair_migration_policy,
                                                       global_scheduler_config.migrate_out_load_threshold,
@@ -44,7 +45,7 @@ def __init__(self, global_scheduler_config: GlobalSchedulerConfig) -> None:
                                                       global_scheduler_config.scale_down_threshold,
                                                       global_scheduler_config.scaling_policy,
                                                       global_scheduler_config.scaling_load_metric,
-                                                      global_scheduler_config.enable_pd_disagg,)
+                                                      global_scheduler_config.enable_pd_disagg)

     def update_instance_infos(self, instance_infos: List[InstanceInfo]) -> None:
         for instance_info in instance_infos:
diff --git a/llumnix/global_scheduler/migration_filter.py b/llumnix/global_scheduler/migration_filter.py
index b2bfd943..3ea3c637 100644
--- a/llumnix/global_scheduler/migration_filter.py
+++ b/llumnix/global_scheduler/migration_filter.py
@@ -26,6 +26,7 @@ class MigrationFilterConfig:
     def __init__(self, migrate_out_load_threshold):
         self.migrate_out_load_threshold: float = migrate_out_load_threshold

+# TODO(KuilongCui): A filter might contain other filters; leave this for the future.
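+# (One possible shape, not implemented here: a composite filter whose
+# filter_src_condition/filter_dst_condition AND together its children's conditions.)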
 class MigrationFilterPolicy(ABC):

     @abstractmethod
     def filter_src_condition(self, filter_config, pair_migration_type) -> Callable[[InstanceInfo], bool]:
         raise NotImplementedError
@@ -36,6 +37,7 @@ def filter_dst_condition(self, filter_config, pair_migration_type) -> Callable[[
         raise NotImplementedError

+
 class MigrationInstanceFilter(ABC):
     def __init__(self, filter_config: MigrationFilterConfig) -> None:
         self.filter_config = filter_config
@@ -76,6 +78,7 @@ def filter_instances(self, instance_infos: List[InstanceInfo],

         return filtered_src_instance_infos, filtered_dst_instance_infos

+
 class LoadConstrainedFilter(MigrationFilterPolicy):
     def filter_src_condition(self, filter_config: MigrationFilterConfig,
                              pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]:
@@ -87,6 +90,7 @@ def filter_dst_condition(self, filter_config: MigrationFilterConfig,
         return lambda instance_info: instance_info.num_killed_requests == 0 \
                                      and instance_info.migration_load_metric < filter_config.migrate_out_load_threshold

+
 class PddFilter(MigrationFilterPolicy):
     INSTANCE_FILTER_RULES = {
         PairMigrationConstraints.DECODING_2_DECODING: (InstanceType.DECODE, InstanceType.DECODE),
@@ -119,6 +123,7 @@ def filter_dst_condition(self, filter_config: MigrationFilterConfig,

         return lambda instance_info: instance_type_filter(instance_info) and policy_filter(instance_info)

+
 class CustomFilter(MigrationFilterPolicy):
     def __init__(self):
         super().__init__()
@@ -140,6 +145,7 @@ def filter_dst_condition(self, filter_config: MigrationFilterConfig,
                              pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]:
         return self.dst_filter

+
 class MigrationFilterPolicyFactory:
     _POLICY_REGISTRY = {
         'load': LoadConstrainedFilter,
diff --git a/llumnix/global_scheduler/migration_policy.py b/llumnix/global_scheduler/migration_policy.py
index 63989da7..53427db8 100644
--- a/llumnix/global_scheduler/migration_policy.py
+++ b/llumnix/global_scheduler/migration_policy.py
@@ -29,6 +29,7 @@ class PairMigrationConstraints(str, Enum):
     DECODING_2_DECODING = "DECODING_2_DECODING"
     PREFILL_2_DECODING = "PREFILL_2_DECODING"

+
 class PairMigrationPolicy(ABC):
     def __init__(self, migrate_out_load_threshold: float) -> None:
         self.migrate_out_load_threshold = migrate_out_load_threshold
@@ -49,6 +50,7 @@ def sort_instance_infos(self, instance_infos: List[InstanceInfo], descending: bo
         )
         return sorted_instance_infos

+
 class Balanced(PairMigrationPolicy):
     def pair_migration(self,
                        src_instance_infos: List[InstanceInfo],
@@ -70,6 +72,7 @@ def pair_migration(self,
                                                sorted_dst_instance_infos[i].instance_id))
         return migrate_instance_pairs

+
 class DefragConstrained(PairMigrationPolicy):
     def pair_migration(self,
                        src_instance_infos: List[InstanceInfo],
@@ -83,6 +86,7 @@ def pair_migration(self,
             migrate_instance_pairs.append((sorted_src_instance_infos[i].instance_id,
                                            sorted_dst_instance_infos[i].instance_id))
         return migrate_instance_pairs

+
 class PairMigrationPolicyFactory:
     _POLICY_REGISTRY = {
         'balanced': Balanced,
diff --git a/llumnix/global_scheduler/scaling_policy.py b/llumnix/global_scheduler/scaling_policy.py
new file mode 100644
index 00000000..e02639d7
--- /dev/null
+++ b/llumnix/global_scheduler/scaling_policy.py
@@ -0,0 +1,86 @@
+from typing import List
+from abc import ABC, abstractmethod
+
+from llumnix.logging.logger import init_logger
+from llumnix.instance_info import InstanceInfo, InstanceLoadCalculator
+
+logger = init_logger(__name__)
+
+
+class ScalePolicy(ABC):
+    def __init__(self, scaling_load_metric: str) -> None:
+        self.scaling_load_calculator = ScalingLoadComputation(scaling_load_metric)
+
+    @abstractmethod
+    def compute_load_metric_up(self, instance_infos: List[InstanceInfo]) -> float:
+        pass
+
+    @abstractmethod
+    def compute_load_metric_down(self, instance_infos: List[InstanceInfo]) -> float:
+        pass
+
+    def compute_load_metric_avg(self, instance_infos: List[InstanceInfo]) -> float:
+        tot_instance_info = InstanceInfo()
+        tot_instance_info.instance_id = -1
+        tot_instance_info.step_id = -1
+        tot_instance_info.num_running_requests = sum([i.num_running_requests for i in instance_infos])
+        tot_instance_info.num_waiting_requests = sum([i.num_waiting_requests for i in instance_infos])
+        tot_instance_info.num_free_gpu_blocks = sum([i.num_free_gpu_blocks for i in instance_infos])
+        tot_instance_info.num_total_gpu_blocks = sum([i.num_total_gpu_blocks for i in instance_infos])
+        tot_instance_info.num_watermark_blocks = sum([i.num_watermark_blocks for i in instance_infos])
+        tot_instance_info.num_blocks_all_waiting_requests = sum([i.num_blocks_all_waiting_requests for i in instance_infos])
+        tot_instance_info.num_available_gpu_blocks = tot_instance_info.num_free_gpu_blocks - tot_instance_info.num_watermark_blocks
+        return self.scaling_load_calculator.compute_instance_load(tot_instance_info)
+
+
+class MaxLoad(ScalePolicy):
+    def compute_load_metric_up(self, instance_infos: List[InstanceInfo]) -> float:
+        return max([i.instance_load_dispatch_scale for i in instance_infos])
+
+    def compute_load_metric_down(self, instance_infos: List[InstanceInfo]) -> float:
+        return max([i.instance_load_dispatch_scale for i in instance_infos])
+
+
+class MinLoad(ScalePolicy):
+    def compute_load_metric_up(self, instance_infos: List[InstanceInfo]) -> float:
+        return min([i.instance_load_dispatch_scale for i in instance_infos])
+
+    def compute_load_metric_down(self, instance_infos: List[InstanceInfo]) -> float:
+        return min([i.instance_load_dispatch_scale for i in instance_infos])
+
+
+class AvgLoad(ScalePolicy):
+    def compute_load_metric_up(self, instance_infos: List[InstanceInfo]) -> float:
+        return self.compute_load_metric_avg(instance_infos)
+
+    def compute_load_metric_down(self, instance_infos: List[InstanceInfo]) -> float:
+        num_instances = len(instance_infos)
+        tot_instance_info = InstanceInfo()
+        tot_instance_info.instance_id = -1
+        tot_instance_info.step_id = -1
+        # the average load after scaling down the last instance
+        tot_instance_info.num_running_requests = sum([i.num_running_requests for i in instance_infos])
+        tot_instance_info.num_waiting_requests = sum([i.num_waiting_requests for i in instance_infos])
+        tot_instance_info.num_free_gpu_blocks = sum([i.num_free_gpu_blocks - i.num_total_gpu_blocks
+                                                     if i.instance_id + 1 == num_instances else i.num_free_gpu_blocks
+                                                     for i in instance_infos])
+        tot_instance_info.num_free_gpu_blocks = max(0, tot_instance_info.num_free_gpu_blocks)
+        tot_instance_info.num_total_gpu_blocks = sum([0 if i.instance_id + 1 == num_instances else i.num_total_gpu_blocks
+                                                      for i in instance_infos])
+        tot_instance_info.num_watermark_blocks = sum([0 if i.instance_id + 1 == num_instances else i.num_watermark_blocks
+                                                      for i in instance_infos])
+        tot_instance_info.num_blocks_all_waiting_requests = sum([i.num_blocks_all_waiting_requests for i in instance_infos])
+        tot_instance_info.num_available_gpu_blocks = tot_instance_info.num_free_gpu_blocks - tot_instance_info.num_watermark_blocks
+        return self.scaling_load_calculator.compute_instance_load(tot_instance_info)
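+    # Worked example for the estimate above (hypothetical numbers): with two instances
+    # of 100 total blocks each, 40 and 60 blocks free, dropping the "last" instance
+    # leaves num_total_gpu_blocks = 100 and num_free_gpu_blocks = max(0, 40 + (60 - 100)) = 0,
+    # i.e. the headroom left if the removed instance's blocks had to be absorbed.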
+
+
+class ScalePolicyFactory:
+    _POLICY_REGISTRY = {
+        'max_load': MaxLoad,
+        'min_load': MinLoad,
+        'avg_load': AvgLoad,
+    }
+
+    @classmethod
+    def get_policy(cls, policy_name: str, **kwargs) -> ScalePolicy:
+        return cls._POLICY_REGISTRY[policy_name](**kwargs)
diff --git a/llumnix/global_scheduler/scaling_scheduler.py b/llumnix/global_scheduler/scaling_scheduler.py
index 7abe8a45..e54a674c 100644
--- a/llumnix/global_scheduler/scaling_scheduler.py
+++ b/llumnix/global_scheduler/scaling_scheduler.py
@@ -12,7 +12,7 @@
 # limitations under the License.

 from typing import Dict, List, Tuple, Set
-from abc import ABC, abstractmethod
+from enum import Enum
 import numpy as np

 from llumnix.logging.logger import init_logger
@@ -89,81 +89,3 @@ def get_empty_instance_info(self) -> InstanceInfo:
         dummy_intance_info.num_free_gpu_blocks = np.inf
         dummy_intance_info.num_available_gpu_blocks_waiting = np.inf
         return dummy_intance_info
-
-
-class ScalePolicy(ABC):
-    def __init__(self, scaling_load_metric: str) -> None:
-        self.scaling_load_calculator = ScalingLoadComputation(scaling_load_metric)
-
-    @abstractmethod
-    def compute_load_metric_up(self, instance_infos: List[InstanceInfo]) -> float:
-        pass
-
-    @abstractmethod
-    def compute_load_metric_down(self, instance_infos: List[InstanceInfo]) -> float:
-        pass
-
-    def compute_load_metric_avg(self, instance_infos: List[InstanceInfo]) -> float:
-        tot_instance_info = InstanceInfo()
-        tot_instance_info.instance_id = -1
-        tot_instance_info.step_id = -1
-        tot_instance_info.num_running_requests = sum([i.num_running_requests for i in instance_infos])
-        tot_instance_info.num_waiting_requests = sum([i.num_waiting_requests for i in instance_infos])
-        tot_instance_info.num_free_gpu_blocks = sum([i.num_free_gpu_blocks for i in instance_infos])
-        tot_instance_info.num_total_gpu_blocks = sum([i.num_total_gpu_blocks for i in instance_infos])
-        tot_instance_info.num_watermark_blocks = sum([i.num_watermark_blocks for i in instance_infos])
-        tot_instance_info.num_blocks_all_waiting_requests = sum([i.num_blocks_all_waiting_requests for i in instance_infos])
-        tot_instance_info.num_available_gpu_blocks = tot_instance_info.num_free_gpu_blocks - tot_instance_info.num_watermark_blocks
-        return self.scaling_load_calculator.compute_instance_load(tot_instance_info)
-
-class MaxLoad(ScalePolicy):
-    def compute_load_metric_up(self, instance_infos: List[InstanceInfo]) -> float:
-        return max([i.instance_load_dispatch_scale for i in instance_infos])
-
-    def compute_load_metric_down(self, instance_infos: List[InstanceInfo]) -> float:
-        return max([i.instance_load_dispatch_scale for i in instance_infos])
-
-
-class MinLoad(ScalePolicy):
-    def compute_load_metric_up(self, instance_infos: List[InstanceInfo]) -> float:
-        return min([i.instance_load_dispatch_scale for i in instance_infos])
-
-    def compute_load_metric_down(self, instance_infos: List[InstanceInfo]) -> float:
-        return min([i.instance_load_dispatch_scale for i in instance_infos])
-
-
-class AvgLoad(ScalePolicy):
-    def compute_load_metric_up(self, instance_infos: List[InstanceInfo]) -> float:
-        return self.compute_load_metric_avg(instance_infos)
-
-    def compute_load_metric_down(self, instance_infos: List[InstanceInfo]) -> float:
-        num_instances = len(instance_infos)
-        tot_instance_info = InstanceInfo()
-        tot_instance_info.instance_id = -1
-        tot_instance_info.step_id = -1
-        # the average load after scale down the last instance
-        tot_instance_info.num_running_requests = sum([i.num_running_requests for i in instance_infos])
-        tot_instance_info.num_waiting_requests = sum([i.num_waiting_requests for i in instance_infos])
-        tot_instance_info.num_free_gpu_blocks = sum([i.num_free_gpu_blocks - i.num_total_gpu_blocks
-                                                     if i.instance_id + 1 == num_instances else i.num_free_gpu_blocks
-                                                     for i in instance_infos])
-        tot_instance_info.num_free_gpu_blocks = max(0, tot_instance_info.num_free_gpu_blocks)
-        tot_instance_info.num_total_gpu_blocks = sum([0 if i.instance_id + 1 == num_instances else i.num_total_gpu_blocks
-                                                      for i in instance_infos])
-        tot_instance_info.num_watermark_blocks = sum([0 if i.instance_id + 1 == num_instances else i.num_watermark_blocks
-                                                      for i in instance_infos])
-        tot_instance_info.num_blocks_all_waiting_requests = sum([i.num_blocks_all_waiting_requests for i in instance_infos])
-        tot_instance_info.num_available_gpu_blocks = tot_instance_info.num_free_gpu_blocks - tot_instance_info.num_watermark_blocks
-        return self.scaling_load_calculator.compute_instance_load(tot_instance_info)
-
-
-class ScalePolicyFactory:
-    _POLICY_REGISTRY = {
-        'max_load': MaxLoad,
-        'min_load': MinLoad,
-        'avg_load': AvgLoad,
-    }
-
-    @classmethod
-    def get_policy(cls, policy_name: str, **kwargs) -> ScalePolicy:
-        return cls._POLICY_REGISTRY[policy_name](**kwargs)
diff --git a/llumnix/internal_config.py b/llumnix/internal_config.py
index 013882c0..03a0a5be 100644
--- a/llumnix/internal_config.py
+++ b/llumnix/internal_config.py
@@ -42,6 +42,7 @@ def __init__(
         self,
         initial_instances: int,
         dispatch_policy: str,
+        power_of_k_choice: int,
         pair_migration_policy: str,
         migrate_out_threshold: float,
         scaling_policy: str,
@@ -49,9 +50,11 @@ def __init__(
         scale_up_threshold: float,
         scale_down_threshold: float,
         enable_pd_disagg: bool,
-        is_group_kind_migration_backend: bool,) -> None:
+        is_group_kind_migration_backend: bool) -> None:
         self.initial_instances = initial_instances
         self.dispatch_policy = dispatch_policy
+        self.power_of_k_choice = power_of_k_choice
+
         self.pair_migration_policy = pair_migration_policy
         self.migrate_out_load_threshold = migrate_out_threshold
diff --git a/tests/unit_test/global_scheduler/test_dispatch_scheduler.py b/tests/unit_test/global_scheduler/test_dispatch_scheduler.py
index e32faf3b..fa96af2e 100644
--- a/tests/unit_test/global_scheduler/test_dispatch_scheduler.py
+++ b/tests/unit_test/global_scheduler/test_dispatch_scheduler.py
@@ -18,10 +18,13 @@
 from llumnix.global_scheduler.dispatch_scheduler import DispatchScheduler
 from llumnix.arg_utils import InstanceArgs

-INSTANCE_NUM = 4

 def test_add_instance_and_remove_instance():
     dispatch_scheduler = DispatchScheduler('balanced')
+def init_dispatch_scheduler(policy='load'):
+    instance_load_calculator = InstanceLoadCalculator('remaining_steps', True)
+    dispatch_scheduler = DispatchScheduler(policy, 1, instance_load_calculator, 2)
+    return dispatch_scheduler

     dispatch_scheduler.add_instance('instance_1', InstanceArgs(instance_type="no_constraints"))
     assert len(dispatch_scheduler.available_dispatch_instance_set) == 1
     dispatch_scheduler.remove_instance('instance_1')
@@ -72,6 +75,7 @@ def test_dispatch_to_no_constraints_and_prefill():

 def test_dispatch_balanced():
     num_tests = 100
+    instance_num = 4
     for _ in range(num_tests):
         dispatch_scheduler = DispatchScheduler('balanced')
         instance_num_requests = {}
@@ -85,11 +89,12 @@ def test_dispatch_balanced():

 def test_dispatch_load():
     num_tests = 100
+    instance_num = 4
     for _ in range(num_tests):
         dispatch_scheduler = DispatchScheduler('load')
         instance_num_requests = {}
         instance_info_dict = {}
-        for instance_id in [f'instance_{i}' for i in range(1, INSTANCE_NUM + 1)]:
+        for instance_id in [f'instance_{i}' for i in range(1, instance_num + 1)]:
             instance_info = InstanceInfo()
             instance_info.instance_id = instance_id
             instance_info.dispatch_load_metric = random.random()
@@ -107,11 +112,12 @@ def test_dispatch_load():

 def test_dispatch_queue():
     num_tests = 100
+    instance_num = 4
     for _ in range(num_tests):
         dispatch_scheduler = DispatchScheduler('queue')
         instance_num_requests = {}
         instance_info_dict = {}
-        for instance_id in [f'instance_{i}' for i in range(1, INSTANCE_NUM + 1)]:
+        for instance_id in [f'instance_{i}' for i in range(1, 4 + 1)]:
             instance_info = InstanceInfo()
             instance_info.instance_id = instance_id
             instance_info.num_waiting_requests = random.randint(1, 10)
@@ -129,7 +135,7 @@ def test_dispatch_queue():

 def test_dispatch_rr():
     instance_num = 7
-    dispatch_scheduler = DispatchScheduler("rr")
+    dispatch_scheduler = DispatchScheduler('rr', 3)
     instance_num_requests = {}
     instance_info_dict = {}

@@ -148,3 +154,26 @@ def test_dispatch_rr():
         instance_id = dispatch_scheduler.dispatch()
         target_instance_id = idx%instance_num
         assert instance_id == f'instance_{target_instance_id}'
+
+def test_dispatch_power_of_k_choice():
+    instance_load_calculator = InstanceLoadCalculator('remaining_steps', True)
+    num_tests = 100
+    instance_num = 2
+    for power_of_k_choice in [1, 2, 3]:
+        dispatch_scheduler = DispatchScheduler('load', power_of_k_choice, instance_load_calculator, 2)
+        instance_num_requests = {}
+        instance_info_dict = {}
+        for instance_id in [f'instance_{i}' for i in range(1, 4 + 1)]:
+            instance_info = InstanceInfo()
+            instance_info.instance_id = instance_id
+            instance_info.num_waiting_requests = random.randint(1, 10)
+            instance_info_dict[instance_id] = instance_info
+            if len(dispatch_scheduler.available_dispatch_instance_set) < dispatch_scheduler.num_dispatch_instances:
+                dispatch_scheduler.available_dispatch_instance_set.add(instance_id)
+                instance_num_requests[instance_id] = 0
+        dispatch_scheduler.instance_num_requests = instance_num_requests
+        dispatch_scheduler.instance_info = instance_info_dict
+        instance_id_set = set()
+        for i in range(num_tests):
+            instance_id_set.add(dispatch_scheduler.dispatch())
+        assert len(instance_id_set) == 2
diff --git a/tests/unit_test/global_scheduler/test_global_scheduler.py b/tests/unit_test/global_scheduler/test_global_scheduler.py
index 2c431bf0..18894572 100644
--- a/tests/unit_test/global_scheduler/test_global_scheduler.py
+++ b/tests/unit_test/global_scheduler/test_global_scheduler.py
@@ -24,7 +24,7 @@

 def init_global_scheduler():
-    global_scheduler_config = GlobalSchedulerConfig(0, 'load', 'defrag_constrained', 3.0,
+    global_scheduler_config = GlobalSchedulerConfig(0, 'load', 1, 'defrag_constrained', 3.0,
                                                     'avg_load', 'remaining_steps', 10, 60, False, False)
     global_scheduler = GlobalScheduler(global_scheduler_config)
     return global_scheduler

From 87aa5946ea343589ceb044b76f7ca71a48acef9b Mon Sep 17 00:00:00 2001
From: s5u13b
Date: Thu, 23 Jan 2025 07:48:17 +0000
Subject: [PATCH 37/59] Fix lint

---
 tests/unit_test/global_scheduler/test_dispatch_scheduler.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/unit_test/global_scheduler/test_dispatch_scheduler.py b/tests/unit_test/global_scheduler/test_dispatch_scheduler.py
index fa96af2e..770cd6e1 100644
--- a/tests/unit_test/global_scheduler/test_dispatch_scheduler.py
+++ b/tests/unit_test/global_scheduler/test_dispatch_scheduler.py
@@ -117,7 +117,7 @@ def test_dispatch_queue():
         dispatch_scheduler = DispatchScheduler('queue')
         instance_num_requests = {}
         instance_info_dict = {}
-        for instance_id in [f'instance_{i}' for i in range(1, 4 + 1)]:
+        for instance_id in [f'instance_{i}' for i in range(1, instance_num + 1)]:
             instance_info = InstanceInfo()
             instance_info.instance_id = instance_id
             instance_info.num_waiting_requests = random.randint(1, 10)
@@ -163,7 +163,7 @@ def test_dispatch_power_of_k_choice():
         dispatch_scheduler = DispatchScheduler('load', power_of_k_choice, instance_load_calculator, 2)
         instance_num_requests = {}
         instance_info_dict = {}
-        for instance_id in [f'instance_{i}' for i in range(1, 4 + 1)]:
+        for instance_id in [f'instance_{i}' for i in range(1, instance_num + 1)]:
             instance_info = InstanceInfo()
             instance_info.instance_id = instance_id
             instance_info.num_waiting_requests = random.randint(1, 10)
@@ -174,6 +174,6 @@ def test_dispatch_power_of_k_choice():
         dispatch_scheduler.instance_num_requests = instance_num_requests
         dispatch_scheduler.instance_info = instance_info_dict
         instance_id_set = set()
-        for i in range(num_tests):
+        for _ in range(num_tests):
             instance_id_set.add(dispatch_scheduler.dispatch())
         assert len(instance_id_set) == 2

From 653a3a0a7e7ce22c589c0b4fbb9affa026b91fc3 Mon Sep 17 00:00:00 2001
From: s5u13b
Date: Fri, 7 Feb 2025 08:30:56 +0000
Subject: [PATCH 38/59] Fix lint

---
 llumnix/backends/vllm/executor.py             |  3 +-
 llumnix/backends/vllm/sim_executor.py         | 14 ++++-----
 .../global_scheduler/dispatch_scheduler.py    |  2 +-
 llumnix/global_scheduler/scaling_policy.py    |  2 +-
 llumnix/global_scheduler/scaling_scheduler.py |  3 +-
 llumnix/manager.py                            |  2 +-
 .../test_dispatch_scheduler.py                | 30 +++++++++----------
 7 files changed, 24 insertions(+), 32 deletions(-)

diff --git a/llumnix/backends/vllm/executor.py b/llumnix/backends/vllm/executor.py
index c8418b59..c388fdde 100644
--- a/llumnix/backends/vllm/executor.py
+++ b/llumnix/backends/vllm/executor.py
@@ -14,14 +14,13 @@
 import time
 from collections import defaultdict
-from typing import List, Optional
+from typing import Callable, Dict, List, Optional, Tuple, Type

 import ray
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 # pylint: disable=unused-import
 from ray.util.placement_group import PlacementGroup

 from vllm.executor.executor_base import ExecutorBase
-from vllm.model_executor.layers.sampler import SamplerOutput, CompletionSequenceGroupOutput
 from vllm.executor.ray_gpu_executor import RayGPUExecutor, RayGPUExecutorAsync, RayWorkerWrapper, envs, \
     get_ip, get_vllm_instance_id, get_distributed_init_method, get_open_port
 from vllm.worker.worker_base import WorkerBase
diff --git a/llumnix/backends/vllm/sim_executor.py b/llumnix/backends/vllm/sim_executor.py
index 98180e79..b5fd4029 100644
--- a/llumnix/backends/vllm/sim_executor.py
+++ b/llumnix/backends/vllm/sim_executor.py
@@ -17,8 +17,9 @@
 from vllm.executor.ray_gpu_executor import RayGPUExecutor
-from vllm.sequence import Logprob, SequenceOutput, SequenceGroupOutput, SamplerOutput, ExecuteModelRequest
-from vllm.config import _GB
+from vllm.sequence import Logprob, SequenceOutput, ExecuteModelRequest
+from vllm.model_executor.layers.sampler import SamplerOutput, CompletionSequenceGroupOutput
+from vllm.utils import GiB_bytes

 from llumnix.logging.logger import init_logger
 from llumnix.backends.vllm.utils import get_cache_block_size
@@ -29,7 +30,6 @@
 class SimGPUExecutor(RayGPUExecutor):
     latency_mem: LatencyMemData = None
-
     def __init__(self, *args, **kwargs) -> None:
         RayGPUExecutor.__init__(self, *args, **kwargs)
         self.last_inference_latency = 0
@@ -38,7 +38,7 @@ def __init__(self, *args, **kwargs) -> None:
         self.cache_block_size = get_cache_block_size(
             self.cache_config.block_size, self.model_config, self.parallel_config)
-        self.cache_block_size /= _GB
+        self.cache_block_size /= GiB_bytes
         self.sim_cache_config = SimCacheConfig(self.cache_config.gpu_memory_utilization,
                                                self.cache_config.block_size,
                                                self.scheduler_config.max_num_batched_tokens)
@@ -85,10 +85,6 @@ async def execute_model_async(
                 dummy_sample_output = SequenceOutput(seq_id, 20, {20: Logprob(1.0)})
                 samples.append(dummy_sample_output)
             if samples:
-                output = SequenceGroupOutput(samples, None)
+                output = CompletionSequenceGroupOutput(samples, None)
                 sampler_outputs.append(output)
         return [SamplerOutput(outputs=sampler_outputs)]
-
-    async def send_blocks(self, blocks_len) -> None:
-        migration_latency = (self.cache_block_size * blocks_len) / self.migration_bandwidth
-        await asyncio.sleep(migration_latency)
diff --git a/llumnix/global_scheduler/dispatch_scheduler.py b/llumnix/global_scheduler/dispatch_scheduler.py
index e4c8245d..8ad612a2 100644
--- a/llumnix/global_scheduler/dispatch_scheduler.py
+++ b/llumnix/global_scheduler/dispatch_scheduler.py
@@ -11,7 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Dict, List, Set
+from typing import Dict, Set

 from llumnix.logging.logger import init_logger
 from llumnix.instance_info import InstanceInfo, InstanceType
diff --git a/llumnix/global_scheduler/scaling_policy.py b/llumnix/global_scheduler/scaling_policy.py
index e02639d7..5048ff73 100644
--- a/llumnix/global_scheduler/scaling_policy.py
+++ b/llumnix/global_scheduler/scaling_policy.py
@@ -2,7 +2,7 @@
 from abc import ABC, abstractmethod

 from llumnix.logging.logger import init_logger
-from llumnix.instance_info import InstanceInfo, InstanceLoadCalculator
+from llumnix.instance_info import InstanceInfo

 logger = init_logger(__name__)
diff --git a/llumnix/global_scheduler/scaling_scheduler.py b/llumnix/global_scheduler/scaling_scheduler.py
index e54a674c..90a74c38 100644
--- a/llumnix/global_scheduler/scaling_scheduler.py
+++ b/llumnix/global_scheduler/scaling_scheduler.py
@@ -12,11 +12,10 @@
 # limitations under the License.

 from typing import Dict, List, Tuple, Set
-from enum import Enum
 import numpy as np

 from llumnix.logging.logger import init_logger
-from llumnix.instance_info import InstanceInfo, ScalingLoadComputation, InstanceType
+from llumnix.instance_info import InstanceInfo, InstanceType
 from llumnix.arg_utils import InstanceArgs

 logger = init_logger(__name__)
diff --git a/llumnix/manager.py b/llumnix/manager.py
index 193499e9..cdbc4e23 100644
--- a/llumnix/manager.py
+++ b/llumnix/manager.py
@@ -39,7 +39,7 @@
                            run_async_func_sync,)
 from llumnix.entrypoints.utils import LaunchMode
 from llumnix.queue.queue_type import QueueType
-from llumnix.constants import (CLEAR_REQUEST_INSTANCE_INTERVAL, NO_INSTANCE_RETRY_INTERVAL,
+from llumnix.constants import (CLEAR_REQUEST_INSTANCE_INTERVAL, NO_INSTANCE_RETRY_GENERATE_INTERVAL,
                                WAIT_ALL_MIGRATIONS_DONE_INTERVAL, AUTO_SCALE_UP_INTERVAL, WAIT_PLACEMENT_GROUP_TIMEOUT,
                                CHECK_DEPLOYMENT_STATES_INTERVAL, WATCH_DEPLOYMENT_INTERVAL,
                                WATCH_DEPLOYMENT_INTERVAL_PENDING_INSTANCE)
diff --git a/tests/unit_test/global_scheduler/test_dispatch_scheduler.py b/tests/unit_test/global_scheduler/test_dispatch_scheduler.py
index 770cd6e1..8c335fef 100644
--- a/tests/unit_test/global_scheduler/test_dispatch_scheduler.py
+++ b/tests/unit_test/global_scheduler/test_dispatch_scheduler.py
@@ -19,13 +19,12 @@
 from llumnix.arg_utils import InstanceArgs

-def test_add_instance_and_remove_instance():
-    dispatch_scheduler = DispatchScheduler('balanced')
 def init_dispatch_scheduler(policy='load'):
-    instance_load_calculator = InstanceLoadCalculator('remaining_steps', True)
-    dispatch_scheduler = DispatchScheduler(policy, 1, instance_load_calculator, 2)
+    dispatch_scheduler = DispatchScheduler(policy, 1)
     return dispatch_scheduler

+def test_add_instance_and_remove_instance():
+    dispatch_scheduler = init_dispatch_scheduler('balanced')
     dispatch_scheduler.add_instance('instance_1', InstanceArgs(instance_type="no_constraints"))
     assert len(dispatch_scheduler.available_dispatch_instance_set) == 1
     dispatch_scheduler.remove_instance('instance_1')
@@ -42,10 +41,11 @@ def init_dispatch_scheduler(policy='load'):
     assert len(dispatch_scheduler.available_dispatch_instance_set) == 0

 def test_dispatch_to_no_constraints_and_prefill():
-    dispatch_scheduler = DispatchScheduler('rr')
+    dispatch_scheduler = init_dispatch_scheduler('rr')
+    instance_num = 4
     instance_num_requests = {}
     instance_info_dict = {}
-    for instance_id in [f'instance_{i}' for i in range(INSTANCE_NUM)]:
+    for instance_id in [f'instance_{i}' for i in range(instance_num)]:
         instance_info = InstanceInfo(
             instance_id=instance_id,
             dispatch_load_metric=random.randint(1, 10),
@@ -77,9 +77,9 @@ def test_dispatch_balanced():
     num_tests = 100
     instance_num = 4
     for _ in range(num_tests):
-        dispatch_scheduler = DispatchScheduler('balanced')
+        dispatch_scheduler = init_dispatch_scheduler('balanced')
         instance_num_requests = {}
-        for instance_id in [f'instance_{i}' for i in range(1, INSTANCE_NUM + 1)]:
+        for instance_id in [f'instance_{i}' for i in range(1, instance_num + 1)]:
             dispatch_scheduler.available_dispatch_instance_set.add(instance_id)
             instance_num_requests[instance_id] = random.randint(1, 10)
         dispatch_scheduler.instance_num_requests = instance_num_requests
@@ -91,7 +91,7 @@ def test_dispatch_load():
     num_tests = 100
     instance_num = 4
     for _ in range(num_tests):
-        dispatch_scheduler = DispatchScheduler('load')
+        dispatch_scheduler = init_dispatch_scheduler('load')
         instance_num_requests = {}
         instance_info_dict = {}
         for instance_id in [f'instance_{i}' for i in range(1, instance_num + 1)]:
@@ -114,7 +114,7 @@ def test_dispatch_queue():
     num_tests = 100
     instance_num = 4
     for _ in range(num_tests):
-        dispatch_scheduler = DispatchScheduler('queue')
+        dispatch_scheduler = init_dispatch_scheduler('queue')
         instance_num_requests = {}
         instance_info_dict = {}
         for instance_id in [f'instance_{i}' for i in range(1, instance_num + 1)]:
@@ -135,7 +135,7 @@ def test_dispatch_queue():

 def test_dispatch_rr():
     instance_num = 7
-    dispatch_scheduler = DispatchScheduler('rr', 3)
+    dispatch_scheduler = init_dispatch_scheduler('rr')
     instance_num_requests = {}
     instance_info_dict = {}

@@ -156,11 +156,10 @@ def test_dispatch_power_of_k_choice():
-    instance_load_calculator = InstanceLoadCalculator('remaining_steps', True)
     num_tests = 100
     instance_num = 2
     for power_of_k_choice in [1, 2, 3]:
-        dispatch_scheduler = DispatchScheduler('load', power_of_k_choice, instance_load_calculator, 2)
+        dispatch_scheduler = DispatchScheduler('load', power_of_k_choice)
         instance_num_requests = {}
         instance_info_dict = {}
         for instance_id in [f'instance_{i}' for i in range(1, instance_num + 1)]:
             instance_info = InstanceInfo()
             instance_info.instance_id = instance_id
             instance_info.num_waiting_requests = random.randint(1, 10)
             instance_info_dict[instance_id] = instance_info
-            if len(dispatch_scheduler.available_dispatch_instance_set) < dispatch_scheduler.num_dispatch_instances:
-                dispatch_scheduler.available_dispatch_instance_set.add(instance_id)
-                instance_num_requests[instance_id] = 0
+            dispatch_scheduler.available_dispatch_instance_set.add(instance_id)
+            instance_num_requests[instance_id] = 0
         dispatch_scheduler.instance_num_requests = instance_num_requests
         dispatch_scheduler.instance_info = instance_info_dict
         instance_id_set = set()

From e9bfeb3ebfe9600f783e76b5d9316a98c8478a9f Mon Sep 17 00:00:00 2001
From: s5u13b
Date: Fri, 7 Feb 2025 08:57:29 +0000
Subject: [PATCH 39/59] Fix global scheduler unit test

---
 llumnix/global_scheduler/dispatch_policy.py                 | 2 +-
 llumnix/global_scheduler/dispatch_scheduler.py              | 1 +
 llumnix/global_scheduler/scaling_policy.py                  | 2 +-
 llumnix/global_scheduler/scaling_scheduler.py               | 1 +
 tests/unit_test/global_scheduler/test_dispatch_scheduler.py | 2 +-
 5 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/llumnix/global_scheduler/dispatch_policy.py b/llumnix/global_scheduler/dispatch_policy.py
index c16a8692..1d594b00 100644
--- a/llumnix/global_scheduler/dispatch_policy.py
+++ b/llumnix/global_scheduler/dispatch_policy.py
@@ -62,7 +62,7 @@ def dispatch(self,
         sorted_instance_infos = sort_instance_infos(available_instance_infos, 'dispatch_load_metric')
         instance_info_chosen = random_choice_from_top_k(sorted_instance_infos, power_of_k_choice)
         instance_id = instance_info_chosen.instance_id
-        logger.info("dispatch to {}, load: {}".format(instance_id, instance_info_chosen.instance_load_dispatch_scale))
+        logger.info("dispatch to {}, load: {}".format(instance_id, instance_info_chosen.dispatch_load_metric))
         return instance_id
diff --git a/llumnix/global_scheduler/dispatch_scheduler.py b/llumnix/global_scheduler/dispatch_scheduler.py
index 8ad612a2..e335f127 100644
--- a/llumnix/global_scheduler/dispatch_scheduler.py
+++ b/llumnix/global_scheduler/dispatch_scheduler.py
@@ -17,6 +17,7 @@
 from llumnix.instance_info import InstanceInfo, InstanceType
 from llumnix.arg_utils import InstanceArgs
 from llumnix.constants import DISPATCH_LOG_FREQUENCY
+from llumnix.global_scheduler.dispatch_policy import DispatchPolicyFactory

 logger = init_logger(__name__)
diff --git a/llumnix/global_scheduler/scaling_policy.py b/llumnix/global_scheduler/scaling_policy.py
index 5048ff73..131956e0 100644
--- a/llumnix/global_scheduler/scaling_policy.py
+++ b/llumnix/global_scheduler/scaling_policy.py
@@ -2,7 +2,7 @@
 from abc import ABC, abstractmethod

 from llumnix.logging.logger import init_logger
-from llumnix.instance_info import InstanceInfo
+from llumnix.instance_info import InstanceInfo, ScalingLoadComputation

 logger = init_logger(__name__)
diff --git a/llumnix/global_scheduler/scaling_scheduler.py b/llumnix/global_scheduler/scaling_scheduler.py
index 90a74c38..39e52f28 100644
--- a/llumnix/global_scheduler/scaling_scheduler.py
+++ b/llumnix/global_scheduler/scaling_scheduler.py
@@ -17,6 +17,7 @@
 from llumnix.logging.logger import init_logger
 from llumnix.instance_info import InstanceInfo, InstanceType
 from llumnix.arg_utils import InstanceArgs
+from llumnix.global_scheduler.scaling_policy import ScalePolicyFactory

 logger = init_logger(__name__)
diff --git a/tests/unit_test/global_scheduler/test_dispatch_scheduler.py b/tests/unit_test/global_scheduler/test_dispatch_scheduler.py
index 8c335fef..c21de37c 100644
--- a/tests/unit_test/global_scheduler/test_dispatch_scheduler.py
+++ b/tests/unit_test/global_scheduler/test_dispatch_scheduler.py
@@ -65,7 +65,7 @@ def test_dispatch_to_no_constraints_and_prefill():
         assert instance_id not in dispatch_scheduler.available_dispatch_instance_set

     instance_dispatch_info = defaultdict(int)
-    for _ in range(INSTANCE_NUM * 2):
+    for _ in range(instance_num * 2):
         instance_id = dispatch_scheduler.dispatch()
         instance_dispatch_info[instance_id] += 1

From 331cd7f9eaae484bdeca8dc8478ea106f4c473cf Mon Sep 17 00:00:00 2001
From: s5u13b
Date: Fri, 7 Feb 2025 09:12:09 +0000
Subject: [PATCH 40/59] Fix entrypoints unit test

---
 llumnix/entrypoints/vllm/api_server.py                      | 2 +-
 tests/unit_test/global_scheduler/test_dispatch_scheduler.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/llumnix/entrypoints/vllm/api_server.py b/llumnix/entrypoints/vllm/api_server.py
index e2aa700d..1a5f8dd2 100644
--- a/llumnix/entrypoints/vllm/api_server.py
+++ b/llumnix/entrypoints/vllm/api_server.py
@@ -127,7 +127,7 @@ async def generate_benchmark(request: Request) -> Response:
     final_output = None
     per_token_latency = []
     per_token_latency_breakdown_list = []
-    async for request_output in results_generator:
+    async for request_output in results_generator.generator():
         if await request.is_disconnected():
             # Abort the request if the client disconnects.
             await llumnix_client.abort(request_id)
diff --git a/tests/unit_test/global_scheduler/test_dispatch_scheduler.py b/tests/unit_test/global_scheduler/test_dispatch_scheduler.py
index c21de37c..37ab241a 100644
--- a/tests/unit_test/global_scheduler/test_dispatch_scheduler.py
+++ b/tests/unit_test/global_scheduler/test_dispatch_scheduler.py
@@ -157,7 +157,7 @@ def test_dispatch_rr():

 def test_dispatch_power_of_k_choice():
     num_tests = 100
-    instance_num = 2
+    instance_num = 4
     for power_of_k_choice in [1, 2, 3]:
         dispatch_scheduler = DispatchScheduler('load', power_of_k_choice)
         instance_num_requests = {}
@@ -174,4 +174,4 @@ def test_dispatch_power_of_k_choice():
     instance_id_set = set()
     for _ in range(num_tests):
         instance_id_set.add(dispatch_scheduler.dispatch())
-    assert len(instance_id_set) == 2
+    assert len(instance_id_set) == power_of_k_choice

From 7a65ae65888425567acb7b703ae6f346bd679d54 Mon Sep 17 00:00:00 2001
From: s5u13b
Date: Fri, 7 Feb 2025 11:18:40 +0000
Subject: [PATCH 41/59] Squashed commit of the following:

commit 48c674b2ec0333388e710679df2d163ce2651d01
Author: s5u13b
Date: Fri Feb 7 09:41:05 2025 +0000
    Fix lint

commit 322862b5ce849b45c827d60e93d7ca792e111cfa
Author: s5u13b
Date: Fri Feb 7 09:39:31 2025 +0000
    Fix entrypoints unit test

commit 75af82477e1ca6b6572586870084dde4e4bbd891
Author: s5u13b
Date: Fri Feb 7 08:07:26 2025 +0000
    Fix lint

commit 2818c8d694cf8dc0d58cc357cefead67ef99269b
Author: s5u13b
Date: Fri Feb 7 08:06:08 2025 +0000
    Fix cr

commit a172468cf42b93f1e846a14db5b652b8c54c2b3b
Author: s5u13b
Date: Fri Feb 7 07:01:07 2025 +0000
    Fix lint

commit 3f863b26a1ffbfce778580edac765b889bb95cd6
Author: s5u13b
Date: Fri Feb 7 06:54:18 2025 +0000
    Add back timestamp

commit 2e53b249daf778d41f3680eb554d7a75185468a2
Author: s5u13b
Date: Fri Feb 7 06:45:16 2025 +0000
    Fix lint

commit eea1a3a5c907eac53d7bd7f6f03dc5821cdc8c55
Author: s5u13b
Date: Fri Feb 7 06:37:30 2025 +0000
    Add back timestamps

commit b4a45ef222f54a3cc68231a0d67c36ed12e0fa2d
Author: s5u13b
Date: Fri Feb 7 06:21:48 2025 +0000
    Remove old filter

commit f2df19716888940d76682652b356b8f9f7c1ee2c
Author: s5u13b
Date: Fri Feb 7 06:12:53 2025 +0000
    Add _process_model_outputs back

commit a51cf25d48a4e91272f3e37a95d28aa27a8fa4de
Author: s5u13b
Date: Fri Feb 7 03:46:45 2025 +0000
    Fix abort

commit 1058ec0409bb836dc22319257f5e93f1c722fc8c
Author: s5u13b
Date: Fri Feb 7 02:43:14 2025 +0000
    Remove blank todo

commit 670018e0caa9f7c68d9469179dce13dfa0e21a3e
Author: s5u13b
Date: Fri Feb 7 02:36:27 2025 +0000
    Filter out migrating request

commit fa2fc9c92856470a53da2901865e244b7580e110
Author: s5u13b
Date: Fri Jan 24 06:25:35 2025 +0000
    Remove process_model_outputs request timestamps

commit 2a980caa4c642d4dedc789b203cddac24d7a6c63
Author: s5u13b
Date: Fri Jan 24 06:10:49 2025 +0000
    Fix linting

commit 78a1ab48826c281f5ee64805f77e881d67f3a73d
Author: s5u13b
Date: Fri Jan 24 05:30:15 2025 +0000
    Fix request leaking bug of migration

commit 774205b05b06e398311c69b99194668a3bc92ed2
Author: s5u13b
Date: Fri Jan 24 03:11:08 2025 +0000
    Fix

commit 814521ee14e9bf4b5b9b7e3a5bb6a2754b948785
Author: s5u13b
Date: Fri Jan 24 02:57:20 2025 +0000
    Minors

commit b3f0688759d2505262a32a0c8bb9d55801b2f0ac
Author: s5u13b
Date: Fri Jan 24 01:56:09 2025 +0000
    Change ci timeout-minutes

---
 llumnix/backends/vllm/llm_engine.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llumnix/backends/vllm/llm_engine.py b/llumnix/backends/vllm/llm_engine.py
index a213634a..0263fb4c 100644
--- a/llumnix/backends/vllm/llm_engine.py
+++ b/llumnix/backends/vllm/llm_engine.py
@@ -170,7 +170,7 @@ def _process_model_outputs(self,
             ctx.output_queue.appendleft((outputs, seq_group_metadata_list, scheduler_outputs,
                                          is_async, is_last_step, is_first_step_output, skip))

-        set_timestamp(server_info, 'engine_process_model_outputs_timestamp_begin', time.time())
+        set_timestamp(server_infos, 'engine_process_model_outputs_timestamp_begin', time.time())

         super()._process_model_outputs(ctx, request_id)

From 5b6325c87163434eb851ac1d4d6e24618c3c26ec Mon Sep 17 00:00:00 2001
From: s5u13b
Date: Sat, 8 Feb 2025 06:55:16 +0000
Subject: [PATCH 42/59] Fix host, num_cpus, serve

---
 llumnix/backends/profiling.py                |  4 ++++
 llumnix/backends/vllm/migration_backend.py   |  2 ++
 llumnix/entrypoints/vllm/api_server_actor.py |  8 ++++----
 llumnix/launcher.py                          | 14 +++++++-------
 4 files changed, 17 insertions(+), 11 deletions(-)

diff --git a/llumnix/backends/profiling.py b/llumnix/backends/profiling.py
index b79afcc1..32b98c14 100644
--- a/llumnix/backends/profiling.py
+++ b/llumnix/backends/profiling.py
@@ -35,6 +35,7 @@
 def _pad_to_alignment(x, multiple_of):
     return x + ((-1*x) % multiple_of)

+
 @dataclasses.dataclass
 class LatencyMemData:
     # The latency of each stage
@@ -69,6 +70,7 @@ def get_prefill_dict_kv(self):
     def get_decode_dict_kv(self):
         return map(list, zip(*self.decode_latency.items()))

+
 @dataclasses.dataclass
 class ProfilingResult:
     """Store the profiling result of a model."""
@@ -127,6 +129,7 @@ def fit_from_database(self, parallel_config: SimParallelConfig):
             avg_loss += abs(sim_lat - latency_list[idx])
         print(f"decode sim avg_loss={avg_loss/len(latency_list)}")

+
 class ProfilingDatabase:
     """Store the profiling results of all the models"""
     def __init__(self, database_filename: str, new_database: bool = False):
@@ -198,6 +201,7 @@ def get_latency_mem(backend_type: BackendType, profiling_database: ProfilingData
         return latency_mem
     raise ValueError(f'Unsupported simulator backend: {backend_type}')

+
 if __name__ == "__main__":
     import argparse
     parser = argparse.ArgumentParser()
diff --git a/llumnix/backends/vllm/migration_backend.py b/llumnix/backends/vllm/migration_backend.py
index 32182a8f..acbac62e 100644
--- a/llumnix/backends/vllm/migration_backend.py
+++ b/llumnix/backends/vllm/migration_backend.py
@@ -40,6 +40,7 @@ def exec_method(self, is_driver_worker, handle, *args, **kwargs):

 NUMPY_SUPPORTED_DTYPES = [torch.float32, torch.float16]

+
 class RayRpcMigrationBackend(MigrationBackendBase):
     def __init__(self, migration_config: MigrationConfig, cache_engine: List[CacheEngine], worker_rank, worker_handle_list, \
                  scheduling_strategy, is_driver_worker, gpu_cache) -> None:
@@ -152,6 +153,7 @@ def try_import_gloo():
     except ImportError as e:
         raise ImportError("Gloo is not installed. Please install it first.") from e

+
 class RayColMigrationBackend(MigrationBackendBase):
     def __init__(self, migration_config: MigrationConfig, cache_engine: List[CacheEngine], local_rank,
                  scheduling_strategy, is_driver_worker, gpu_cache) -> None:
diff --git a/llumnix/entrypoints/vllm/api_server_actor.py b/llumnix/entrypoints/vllm/api_server_actor.py
index 6b3067fa..2574f9e7 100644
--- a/llumnix/entrypoints/vllm/api_server_actor.py
+++ b/llumnix/entrypoints/vllm/api_server_actor.py
@@ -25,11 +25,11 @@ def __init__(self, server_name: str, entrypoints_args: EntrypointsArgs):
         logger.info("APIServerActor(job_id={}, worker_id={}, actor_id={}, node_id={}, instance_id={})".format(
                         self.job_id, self.worker_id, self.actor_id, self.node_id, self.instance_id))
         self.entrypoints_args = entrypoints_args
+        self.host = get_ip_address()
         self.request_output_queue_port = self.entrypoints_args.request_output_queue_port
         self.request_output_queue_type = QueueType(self.entrypoints_args.request_output_queue_type)
-        ip = get_ip_address()
         self.request_output_queue = init_request_output_queue_server(
-            ip, self.request_output_queue_port, self.request_output_queue_type)
+            self.host, self.request_output_queue_port, self.request_output_queue_type)

     def __repr__(self):
         return f"{self.__class__.__name__}(iid={self.instance_id[:5]})"
@@ -53,9 +53,9 @@ def _run_uvicorn_server(self,
         llumnix.entrypoints.vllm.api_server.llumnix_client = LlumnixClientVLLM(entrypoints_context)
         app = llumnix.entrypoints.vllm.api_server.app

-        logger.info("Start api server on '{}:{}'.".format(entrypoints_args.host, entrypoints_args.port))
+        logger.info("Start api server on '{}:{}'.".format(self.host, entrypoints_args.port))
         uvicorn.run(app,
-                    host=entrypoints_args.host,
+                    host=self.host,
                     port=entrypoints_args.port,
                     log_level=entrypoints_args.log_level,
                     timeout_keep_alive=llumnix.entrypoints.vllm.api_server.SERVER_TIMEOUT_KEEP_ALIVE,
diff --git a/llumnix/launcher.py b/llumnix/launcher.py
index ddfbe9a9..a7dd04d4 100644
--- a/llumnix/launcher.py
+++ b/llumnix/launcher.py
@@ -36,6 +36,7 @@

 logger = init_logger(__name__)

+
 class Launcher:
     def __init__(self, global_scheduler: GlobalScheduler, enable_port_increment: bool,
                  enable_port_offset_store: bool, enable_pd_disagg: bool,
@@ -58,19 +59,18 @@ def __init__(self, global_scheduler: GlobalScheduler, enable_port_increment: boo
         self.inflight_num_decode = 0

     def init_placement_group(self,
-                            placement_group_name: str,
-                            engine_args,
-                            backend_type: BackendType,
-                            init_server: bool = False,
-                            block: bool = True) -> PlacementGroup:
+                             placement_group_name: str,
+                             engine_args,
+                             backend_type: BackendType,
+                             init_server: bool = False,
+                             block: bool = True) -> PlacementGroup:
+        # num_cpus=2+(0/1), for Llumlet + AsyncPutQueueActor + (ApiServerActor)
         if not BackendType.is_sim_backend(backend_type):
-            # num_cpus=3, for Llumlet + AsyncPutQueueActor + ProxyActor
             # num_gpus=world_size, for world_size Workers
             world_size = get_engine_world_size(engine_args, backend_type)
             placement_group = initialize_placement_group(placement_group_name, num_cpus=3+int(init_server),
                                                          num_gpus=world_size, detached=True, block=block)
         else:
-            # num_cpus=1, for Llumlet + AsyncPutQueueActor
             placement_group = initialize_placement_group(placement_group_name, num_cpus=2+int(init_server),
                                                          num_gpus=0, detached=True, block=block)

From 41588ad157661bba3bbd143396de5cf099c67e76 Mon Sep 17 00:00:00 2001
From: s5u13b
Date: Sat, 8 Feb 2025 07:34:59 +0000
Subject: [PATCH 43/59] Minors

---
 llumnix/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)
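The one-line change below only rewords the bundle comment, but the layout it documents is worth spelling out. A sketch of how those specs map onto Ray's placement group API (the helper name `make_pg_specs` is illustrative, not a real Llumnix function):

```python
import ray

ray.init()

# Mirrors the placement_group_specs construction in llumnix/utils.py.
def make_pg_specs(num_cpus: int, num_gpus: int):
    # bundle_0: the CPU-only actors (Llumlet + AsyncPutQueueActor, plus the API server when enabled);
    # bundles 1..num_gpus: one GPU each for the engine workers.
    return [{"CPU": num_cpus}] + [{"GPU": 1}] * num_gpus

# STRICT_PACK keeps all bundles on a single node, as in the utils.py call.
pg = ray.util.placement_group(make_pg_specs(2, 1), "STRICT_PACK", name="example_pg")
ray.get(pg.ready())  # blocks until the cluster can satisfy all bundles
```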
diff --git a/llumnix/utils.py b/llumnix/utils.py
index dd6426c5..3080dabd 100644
--- a/llumnix/utils.py
+++ b/llumnix/utils.py
@@ -66,7 +66,7 @@ def initialize_placement_group(
             "The number of required GPUs exceeds the total number of "
             "available GPUs in the cluster.")
     # Create a new placement group
-    # bundle_0: Llumlet + AsyncPutQueueActor + ProxyActor, bundle_1: Workers
+    # bundle_0: Llumlet + AsyncPutQueueActor, bundle_1: Workers
     placement_group_specs = ([{"CPU": num_cpus}] + [{"GPU": 1}] * num_gpus)
     current_placement_group = ray.util.placement_group(
         placement_group_specs, "STRICT_PACK", name=placement_group_name, lifetime=lifetime)

From 19c9f0d97e329bd9f41bc551b9db18dd25ea39a4 Mon Sep 17 00:00:00 2001
From: s5u13b
Date: Tue, 11 Feb 2025 07:45:27 +0000
Subject: [PATCH 44/59] Simulator test done

---
 docs/Arguments.md                                | 6 +++---
 llumnix/arg_utils.py                             | 8 +++-----
 llumnix/backends/utils.py                        | 2 +-
 llumnix/config/default.py                        | 2 +-
 llumnix/global_scheduler/migration_policy.py     | 5 ++---
 llumnix/launcher.py                              | 2 +-
 llumnix/manager.py                               | 5 ++++-
 llumnix/utils.py                                 | 2 +-
 tests/unit_test/global_scheduler/test_manager.py | 9 +++++++--
 .../global_scheduler/test_migration_scheduler.py | 2 +-
 10 files changed, 24 insertions(+), 19 deletions(-)

diff --git a/docs/Arguments.md b/docs/Arguments.md
index d1c7b8aa..e5c9e0bd 100644
--- a/docs/Arguments.md
+++ b/docs/Arguments.md
@@ -29,7 +29,7 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h]
                 [--enable-migration]
                 [--enable-defrag]
                 [--pair-migration-frequency PAIR_MIGRATION_FREQUENCY]
-                [--pair-migration-policy {balanced,defrag_constrained,defrag_relaxed}]
+                [--pair-migration-policy {balanced,defrag}]
                 [--migrate-out-threshold MIGRATE_OUT_THRESHOLD]
                 [--request-migration-policy {LCR,SR,LR,FCW,FCWSR}]
                 [--enable-scaling]
@@ -156,8 +156,8 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h]
 `--pair-migration-policy`
 - Pair migration policy.
-- Possible choices: balanced, defrag_constrained, defrag_relaxed
-- Default: "defrag_constrained"
+- Possible choices: balanced, defrag
+- Default: "defrag"

 `--migrate-out-threshold`
 - Migrate out instance load threshold.
diff --git a/llumnix/arg_utils.py b/llumnix/arg_utils.py
index ef47e847..5fcd45ab 100644
--- a/llumnix/arg_utils.py
+++ b/llumnix/arg_utils.py
@@ -245,13 +245,11 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
                             help='pair migration frequency')
         parser.add_argument('--pair-migration-policy',
                             type=str,
-                            choices=['balanced', 'defrag_constrained', 'defrag_relaxed'],
+                            choices=['balanced', 'defrag'],
                             help='The pair migration policy.\n\n'
                                  '* "balanced" pair migration to make the instance load of instance more balanced.\n'
-                                 '* "defrag_constrained" pair migration without balanced constraint to '
-                                 'achieve defragmentation thoroughly (with instance constraints).\n'
-                                 '* "defrag_relaxed" pair migration to without balanced constraint '
-                                 'to achieve defragmentation thoroughly (without instance constraints).\n')
+                                 '* "defrag" pair migration without the balanced constraint, to achieve '
+                                 'defragmentation as thoroughly as possible (instance-type constraints still apply).\n')
         parser.add_argument('--migrate-out-threshold',
                             type=float,
                             help='migrate out instance load threshold')
diff --git a/llumnix/backends/utils.py b/llumnix/backends/utils.py
index 997c1043..c73517e8 100644
--- a/llumnix/backends/utils.py
+++ b/llumnix/backends/utils.py
@@ -63,7 +63,7 @@ async def put_nowait_to_servers(self,
             if isinstance(ret, Exception):
                 server_id = list(server_request_outputs.keys())[idx]
                 server_info = server_info_dict[server_id]
-                logger.warning("Server {} is dead.".format(server_id))
+                logger.error("Server {} is dead, exception: {}".format(server_id, ret))
                 if self.request_output_queue_type == QueueType.ZMQ:
                     logger.warning("request output queue ip: {}, port: {}".format(server_info.request_output_queue_ip,
                                                                                   server_info.request_output_queue_port))
diff --git a/llumnix/config/default.py b/llumnix/config/default.py
index 752b9d99..78be7e04 100644
--- a/llumnix/config/default.py
+++ b/llumnix/config/default.py
@@ -88,7 +88,7 @@
 # Pair migration frequency
 _C.MANAGER.PAIR_MIGRATION_FREQUENCY = 1
 # Pair migration policy
-_C.MANAGER.PAIR_MIGRATION_POLICY = 'defrag_constrained'
+_C.MANAGER.PAIR_MIGRATION_POLICY = 'defrag'

 # Migrate out instance load threshold
 _C.MANAGER.MIGRATE_OUT_THRESHOLD = -3.0
diff --git a/llumnix/global_scheduler/migration_policy.py b/llumnix/global_scheduler/migration_policy.py
index 53427db8..120cae04 100644
--- a/llumnix/global_scheduler/migration_policy.py
+++ b/llumnix/global_scheduler/migration_policy.py
@@ -73,7 +73,7 @@ def pair_migration(self,
         return migrate_instance_pairs


-class DefragConstrained(PairMigrationPolicy):
+class Defrag(PairMigrationPolicy):
     def pair_migration(self,
                        src_instance_infos: List[InstanceInfo],
                        dst_instance_infos: List[InstanceInfo],
@@ -82,7 +82,6 @@ def pair_migration(self,
         sorted_dst_instance_infos = self.sort_instance_infos(dst_instance_infos, descending=False)
         migrate_instance_pairs = []
         for i in range(min(len(sorted_src_instance_infos), len(sorted_dst_instance_infos))):
-            # without any constrain in order to make prefill migrate happens as soon as possible
             migrate_instance_pairs.append((sorted_src_instance_infos[i].instance_id,
                                            sorted_dst_instance_infos[i].instance_id))
         return migrate_instance_pairs
@@ -90,7 +89,7 @@ def pair_migration(self,
 class PairMigrationPolicyFactory:
     _POLICY_REGISTRY = {
         'balanced': Balanced,
-        'defrag_constrained': DefragConstrained,
+        'defrag': Defrag,
     }

     @classmethod
init_placement_group(self, if not BackendType.is_sim_backend(backend_type): # num_gpus=world_size, for world_size Workers world_size = get_engine_world_size(engine_args, backend_type) - placement_group = initialize_placement_group(placement_group_name, num_cpus=3+int(init_server), + placement_group = initialize_placement_group(placement_group_name, num_cpus=2+int(init_server), num_gpus=world_size, detached=True, block=block) else: placement_group = initialize_placement_group(placement_group_name, num_cpus=2+int(init_server), diff --git a/llumnix/manager.py b/llumnix/manager.py index cdbc4e23..2baf014c 100644 --- a/llumnix/manager.py +++ b/llumnix/manager.py @@ -269,6 +269,7 @@ async def migrate_done_callback(ret, migrate_instance_pair: Tuple[str, str]) -> self.request_instance[migrate_out_request_id] = migrate_instance_pair[1] logger.info("Instance {}->{} migrate done, migrate request {}".format( migrate_instance_pair[0], migrate_instance_pair[1], migrate_out_request_ids)) + def migrate_done_callback_wrapper(migrate_instance_pair: Tuple[str, str], fut) -> None: ret = fut.result()[0] loop = asyncio.get_event_loop() @@ -320,7 +321,9 @@ async def _auto_scale_up_loop(self, interval: float) -> None: self.scale_down(instance_id) alive_pg_states = list_placement_groups(filters=[("state", "!=", "REMOVED")]) if self.max_instances != -1 and len(alive_pg_states) >= self.max_instances: - time.sleep(interval) + logger.debug("The number of alive placement groups has reached the max_instances.") + await asyncio.sleep(interval) + continue if new_pg is None: new_instance_id = random_uuid() new_pg = self.launcher.init_placement_group(get_placement_group_name(new_instance_id), self.engine_args, self.backend_type, diff --git a/llumnix/utils.py b/llumnix/utils.py index 3080dabd..7ad45d6a 100644 --- a/llumnix/utils.py +++ b/llumnix/utils.py @@ -66,7 +66,7 @@ def initialize_placement_group( "The number of required GPUs exceeds the total number of " "available GPUs in the cluster.") # Create a new placement group - # bundle_0: Llumlet + AsyncPutQueueActor, bundle_1: Workers + # bundle_0: Llumlet + AsyncPutQueueActor, bundle_(1-num_gpus): Workers placement_group_specs = ([{"CPU": num_cpus}] + [{"GPU": 1}] * num_gpus) current_placement_group = ray.util.placement_group( placement_group_specs, "STRICT_PACK", name=placement_group_name, lifetime=lifetime) diff --git a/tests/unit_test/global_scheduler/test_manager.py b/tests/unit_test/global_scheduler/test_manager.py index c1c3d9a6..daf0eef1 100644 --- a/tests/unit_test/global_scheduler/test_manager.py +++ b/tests/unit_test/global_scheduler/test_manager.py @@ -146,9 +146,9 @@ async def get_instance_deployment_states(self, instance_id: str): return self.launcher.get_instance_deployment_states(instance_id) def init_manager_with_launch_mode(launch_mode, request_output_queue_type="rayqueue", - enable_pd_disagg=False, pd_ratio="1:3"): + enable_pd_disagg=False, pd_ratio="1:3", max_instances=-1): manager_args = ManagerArgs(enable_port_increment=True, enable_port_offset_store=True, - enable_pd_disagg=enable_pd_disagg, pd_ratio=pd_ratio) + enable_pd_disagg=enable_pd_disagg, pd_ratio=pd_ratio, max_instances=max_instances) instance_args = InstanceArgs(migration_backend="rayrpc") entrypoints_args = EntrypointsArgs(host="127.0.0.1", port=8000, request_output_queue_type=request_output_queue_type) engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True) @@ -544,3 +544,8 @@ async def test_pd_disagg_deployment_states(): manager.scale_up(prefill_instance_ids, 
[None]*len(prefill_instance_ids), [InstanceArgs(instance_type="prefill")]*len(prefill_instance_ids)) assert not manager._inner_check_pd_deployment() +def test_auto_scale_up_loop_max_instances(ray_env): + manager, _, _, _, _ = init_manager_with_launch_mode(LaunchMode.GLOBAL, "rayqueue", max_instances=2) + time.sleep(30.0) + num_instances = ray.get(manager.scale_up.remote([], [])) + assert num_instances == 2 diff --git a/tests/unit_test/global_scheduler/test_migration_scheduler.py b/tests/unit_test/global_scheduler/test_migration_scheduler.py index 669ce074..5f96405b 100644 --- a/tests/unit_test/global_scheduler/test_migration_scheduler.py +++ b/tests/unit_test/global_scheduler/test_migration_scheduler.py @@ -96,7 +96,7 @@ def test_migration_filter(pair_migration_type): assert instance.instance_type == InstanceType.DECODE assert instance.num_killed_requests == 0 -@pytest.mark.parametrize("policy", ['balanced', 'defrag_constrained']) +@pytest.mark.parametrize("policy", ['balanced', 'defrag']) def test_pair_migration(policy): num_tests = 1000 exist_migration = False From 2ad98e7f8e81e2ef7db519a367646b480db88c2d Mon Sep 17 00:00:00 2001 From: s5u13b Date: Tue, 11 Feb 2025 10:35:48 +0000 Subject: [PATCH 45/59] Fix manager unit test --- llumnix/launcher.py | 62 ++++++++++++------- llumnix/manager.py | 21 ++++--- .../global_scheduler/test_manager.py | 12 ++-- 3 files changed, 56 insertions(+), 39 deletions(-) diff --git a/llumnix/launcher.py b/llumnix/launcher.py index 5cacf3a7..6d29d3a9 100644 --- a/llumnix/launcher.py +++ b/llumnix/launcher.py @@ -38,9 +38,13 @@ class Launcher: - def __init__(self, global_scheduler: GlobalScheduler, enable_port_increment: bool, - enable_port_offset_store: bool, enable_pd_disagg: bool, - enablde_engine_pd_disagg: bool, pd_ratio: List[int]): + def __init__(self, + global_scheduler: GlobalScheduler, + enable_port_increment: bool, + enable_port_offset_store: bool, + enable_pd_disagg: bool, + enablde_engine_pd_disagg: bool, + pd_ratio: List[int]): self.global_scheduler = global_scheduler self.enable_port_increment = enable_port_increment self.enable_port_offset_store = enable_port_offset_store @@ -115,7 +119,10 @@ def clear_instance_ray_resources(self, instance_id: str): if not kill_instance(instance_id): logger.debug("Failed to kill instance {}.".format(instance_id)) - def _get_next_instance_type(self, cur_num_prefill, cur_num_decode, pd_ratio) -> str: + def _get_next_instance_type(self, + cur_num_prefill: int, + cur_num_decode: int, + pd_ratio: List[int]) -> str: instance_type = InstanceType.NO_CONSTRAINTS if self.enable_pd_disagg: @@ -148,7 +155,7 @@ def _get_next_instance_type(self, cur_num_prefill, cur_num_decode, pd_ratio) -> return instance_type - def _get_next_instance_args(self, instance_args) -> InstanceArgs: + def _get_next_instance_args(self, instance_args: InstanceArgs) -> InstanceArgs: assert not self.enablde_engine_pd_disagg, \ "Currently not support engine based pd-disaggregation in global launch mode." 
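The hunks above re-type the `Launcher` helpers; `_get_next_instance_type` itself chooses prefill or decode so that the launched mix drifts as little as possible from the configured `--pd-ratio`. A standalone, hedged sketch mirroring that decision rule (in-flight counters omitted; names simplified, not the real method):

```python
# Hedged sketch of the pd_ratio-driven instance-type selection.
from typing import List

def next_instance_type(num_prefill: int, num_decode: int, pd_ratio: List[int]) -> str:
    if num_prefill == 0:
        return "prefill"
    if num_decode == 0:
        return "decode"
    normal_distance = pd_ratio[0] - pd_ratio[1]
    # strip out complete ratio groups, keep only the remainder
    base = min(num_prefill // pd_ratio[0], num_decode // pd_ratio[1])
    p = num_prefill - base * pd_ratio[0]
    d = num_decode - base * pd_ratio[1]
    if p + d == 0:
        return "prefill"
    gap_if_prefill = abs((p + 1 - d) - normal_distance)
    gap_if_decode = abs((p - (d + 1)) - normal_distance)
    return "prefill" if gap_if_prefill <= gap_if_decode else "decode"

# With pd_ratio 1:3, one prefill and two decode instances already running,
# launching a decode instance moves the mix closer to the target ratio.
assert next_instance_type(1, 2, [1, 3]) == "decode"
```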
@@ -169,10 +176,15 @@ def _get_next_entrypoints_args(self, entrypoints_args: EntrypointsArgs) -> Entry put_actor_data_to_ray_internal_kv("manager", "port_offset", self.port_offset) return next_entrypoints_args - def init_server_and_instance(self, instance_id: str, entrypoints_args: EntrypointsArgs, - instance_args: InstanceArgs, engine_args, backend_type: BackendType, - placement_group: PlacementGroup, instance_finish_cb: Callable = None, - server_finish_cb: Callable = None): + def init_server_and_instance(self, + instance_id: str, + entrypoints_args: EntrypointsArgs, + instance_args: InstanceArgs, + engine_args, + backend_type: BackendType, + placement_group: PlacementGroup, + instance_ready_cb: Callable = None, + server_ready_cb: Callable = None): async def done_scale_up(instance_args: InstanceArgs, entrypoint_args: EntrypointsArgs): try: manager = ray.get_actor(get_manager_name(), namespace="llumnix") @@ -180,15 +192,15 @@ async def done_scale_up(instance_args: InstanceArgs, entrypoint_args: Entrypoint await server.run.remote(manager, instance_id, instance) self.inflight_num_prefill -= 1 if instance_args.instance_type == InstanceType.PREFILL else 0 self.inflight_num_decode -= 1 if instance_args.instance_type == InstanceType.DECODE else 0 - if instance_finish_cb: + if instance_ready_cb: # manager.scale_up will be called here after the instance is ready - instance_finish_cb(instance_id, instance, instance_args) - if server_finish_cb: - server_finish_cb(instance_id, server) - logger.info("Launcher init_server_and_instance done, instance_id: {}, instance_type: {}, " - "api_server_port: {}, request_output_queue_port: {}".format(instance_id, - instance_args.instance_type, entrypoint_args.port, - entrypoint_args.request_output_queue_port)) + instance_ready_cb(instance_id, instance, instance_args) + if server_ready_cb: + server_ready_cb(instance_id, server) + logger.info("Init server and instance done, instance_id: {}, instance_type: {}, " + "api_server_port: {}, request_output_queue_port: {}".format( + instance_id, instance_args.instance_type, + entrypoint_args.port, entrypoint_args.request_output_queue_port)) # pylint: disable=broad-except except Exception as e: self.inflight_num_prefill -= 1 if instance_args.instance_type == InstanceType.PREFILL else 0 @@ -208,18 +220,20 @@ async def done_scale_up(instance_args: InstanceArgs, entrypoint_args: Entrypoint self.inflight_num_decode += 1 if next_instance_args.instance_type == InstanceType.DECODE else 0 asyncio.create_task(done_scale_up(next_instance_args, next_entrypoints_args)) - def init_server(self, server_name: str, placement_group: PlacementGroup, + def init_server(self, + server_name: str, + placement_group: PlacementGroup, entrypoints_args: EntrypointsArgs) -> APIServerActor: fastapi_server = APIServerActor.from_args(server_name, placement_group, entrypoints_args) return fastapi_server def init_instance(self, - instance_id: str, - instance_args: InstanceArgs, - placement_group: PlacementGroup, - request_output_queue_type: QueueType, - backend_type: BackendType, - engine_args + instance_id: str, + instance_args: InstanceArgs, + placement_group: PlacementGroup, + request_output_queue_type: QueueType, + backend_type: BackendType, + engine_args ) -> Tuple[str, Llumlet]: instance = Llumlet.from_args( instance_id, diff --git a/llumnix/manager.py b/llumnix/manager.py index 2baf014c..696ad4db 100644 --- a/llumnix/manager.py +++ b/llumnix/manager.py @@ -339,7 +339,7 @@ async def _auto_scale_up_loop(self, interval: float) -> None: continue 
self.launcher.init_server_and_instance(new_instance_id, self.entrypoints_args, self.instance_args, self.engine_args, self.backend_type, new_pg, - instance_finish_cb=self.scale_up) + instance_ready_cb=self.scale_up) logger.info("Deploy server and instance to new placement group done, instance_id: {}.".format(new_instance_id)) # pylint: disable=broad-except except Exception as e: @@ -405,16 +405,17 @@ async def run_task(alive_instances: List[str], task_name: str, *args, **kwargs): # Restore migrate config self.enable_migration = origin_config - def scale_up(self, instance_id: Union[str, Iterable[str]], - instance_actor_handle: Union[ray.actor.ActorHandle, List[ray.actor.ActorHandle]], - instance_arg: Union[InstanceArgs, Iterable[InstanceArgs]]) -> None: + def scale_up(self, + instance_id: Union[str, Iterable[str]], + instance_actor_handle: Union[ray.actor.ActorHandle, Iterable[ray.actor.ActorHandle]], + instance_args: Union[InstanceArgs, Iterable[InstanceArgs]]) -> None: if isinstance(instance_id, str): instance_id = [instance_id,] instance_actor_handle = [instance_actor_handle,] - instance_arg = [instance_arg,] + instance_args = [instance_args,] instance_ids = list(instance_id) instance_actor_handles = list(instance_actor_handle) - instance_args = list(instance_arg) + instance_args_list = list(instance_args) indeed_update = False no_pending_instance = (self.pending_rebuild_migration_instances == 0) @@ -427,7 +428,7 @@ def scale_up(self, instance_id: Union[str, Iterable[str]], if self.log_instance_info: self.instance_last_logged_empty[ins_id] = False self.pending_rebuild_migration_instances += 1 - self.global_scheduler.scale_up(instance_ids, instance_args) + self.global_scheduler.scale_up(instance_ids, instance_args_list) self.num_instances = len(self.instances) # When scaling up, we need to rebuild the migration backend. 
But if initially self.pending_rebuild_migration_instances != 0, @@ -559,13 +560,13 @@ def _inner_check_pd_deployment(self) -> str: scale_down_instance_id = "" if cur_num_prefill == 0 and cur_num_decode > 0: scale_down_instance_id = random.choice(list(decode_instance_ids)) - logger.info("[_inner_check_pd_deployment] pd_ratio: {}, cur_num_prefill: {}, cur_num_decode: {}, " + logger.info("Check pd deployment, pd_ratio: {}, cur_num_prefill: {}, cur_num_decode: {}, " "all decode, scale down decode instance {}".format(self.manager_args.pd_ratio, cur_num_prefill, cur_num_decode, scale_down_instance_id)) if cur_num_decode == 0 and cur_num_prefill > 0: scale_down_instance_id = random.choice(list(prefill_instance_ids)) - logger.info("[_inner_check_pd_deployment] pd_ratio: {}, cur_num_prefill: {}, cur_num_decode: {}, " + logger.info("Check pd deployment, pd_ratio: {}, cur_num_prefill: {}, cur_num_decode: {}, " "all prefill, scale down prefill instance {}".format(self.manager_args.pd_ratio, cur_num_prefill, cur_num_decode, scale_down_instance_id)) @@ -638,7 +639,7 @@ async def watch_instance_deployment_states(instance_id: str): async def is_ready(self) -> bool: """Called by api server, return true when all the instances have been successfully created.""" tasks = [instance.is_ready.remote() for instance in self.instances.values()] - is_ready_list = await asyncio.gather(*tasks) + is_ready_list = await asyncio.gather(*tasks, return_exceptions=True) return all(is_ready_list) async def _check_instance_error(self, migrate_instance_pairs: Tuple[str, str]) -> List[bool]: diff --git a/tests/unit_test/global_scheduler/test_manager.py b/tests/unit_test/global_scheduler/test_manager.py index daf0eef1..50f353bc 100644 --- a/tests/unit_test/global_scheduler/test_manager.py +++ b/tests/unit_test/global_scheduler/test_manager.py @@ -147,8 +147,8 @@ async def get_instance_deployment_states(self, instance_id: str): def init_manager_with_launch_mode(launch_mode, request_output_queue_type="rayqueue", enable_pd_disagg=False, pd_ratio="1:3", max_instances=-1): - manager_args = ManagerArgs(enable_port_increment=True, enable_port_offset_store=True, - enable_pd_disagg=enable_pd_disagg, pd_ratio=pd_ratio, max_instances=max_instances) + manager_args = ManagerArgs(enable_port_increment=True, enable_pd_disagg=enable_pd_disagg, + pd_ratio=pd_ratio, max_instances=max_instances) instance_args = InstanceArgs(migration_backend="rayrpc") entrypoints_args = EntrypointsArgs(host="127.0.0.1", port=8000, request_output_queue_type=request_output_queue_type) engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True) @@ -544,8 +544,10 @@ async def test_pd_disagg_deployment_states(): manager.scale_up(prefill_instance_ids, [None]*len(prefill_instance_ids), [InstanceArgs(instance_type="prefill")]*len(prefill_instance_ids)) assert not manager._inner_check_pd_deployment() -def test_auto_scale_up_loop_max_instances(ray_env): + +@pytest.mark.asyncio +async def test_auto_scale_up_loop_max_instances(ray_env): manager, _, _, _, _ = init_manager_with_launch_mode(LaunchMode.GLOBAL, "rayqueue", max_instances=2) - time.sleep(30.0) - num_instances = ray.get(manager.scale_up.remote([], [])) + await asyncio.sleep(60.0) + num_instances = manager.scale_up([], [], []) assert num_instances == 2 From dd93cf002ab3a4813743b36736bb12c29e389b4e Mon Sep 17 00:00:00 2001 From: s5u13b Date: Wed, 12 Feb 2025 01:48:58 +0000 Subject: [PATCH 46/59] Fix init_instances and simulator test --- llumnix/manager.py | 19 ++++++++++++---- 
.../global_scheduler/test_global_scheduler.py | 2 +- .../global_scheduler/test_manager.py | 22 ++++++------------- 3 files changed, 23 insertions(+), 20 deletions(-) diff --git a/llumnix/manager.py b/llumnix/manager.py index 696ad4db..0f6ca99a 100644 --- a/llumnix/manager.py +++ b/llumnix/manager.py @@ -547,9 +547,20 @@ def init_instances(self, instance_ids.append(instance_id) instances.append(instance) - self.scale_up(instance_ids, instances, [instance_args]*len(instance_ids)) + available_instance_ids = [] + available_instances = [] + for instance_id, instance in zip(instance_ids, instances): + try: + ray.get(instance.is_ready.remote()) + available_instance_ids.append(instance_id) + available_instances.append(instance) + except ray.exceptions.RayActorError: + logger.error("Instance {} is dead".format(instance_id)) + self.launcher.clear_instance_ray_resources(instance_id) + + self.scale_up(available_instance_ids, available_instances, [instance_args]*len(available_instance_ids)) - return instance_ids, instances + return available_instance_ids, available_instances def _inner_check_pd_deployment(self) -> str: prefill_instance_ids = self.global_scheduler.dispatch_scheduler.available_dispatch_instance_set @@ -561,13 +572,13 @@ def _inner_check_pd_deployment(self) -> str: if cur_num_prefill == 0 and cur_num_decode > 0: scale_down_instance_id = random.choice(list(decode_instance_ids)) logger.info("Check pd deployment, pd_ratio: {}, cur_num_prefill: {}, cur_num_decode: {}, " - "all decode, scale down decode instance {}".format(self.manager_args.pd_ratio, + "all instances are decode instances, scale down decode instance {}".format(self.manager_args.pd_ratio, cur_num_prefill, cur_num_decode, scale_down_instance_id)) if cur_num_decode == 0 and cur_num_prefill > 0: scale_down_instance_id = random.choice(list(prefill_instance_ids)) logger.info("Check pd deployment, pd_ratio: {}, cur_num_prefill: {}, cur_num_decode: {}, " - "all prefill, scale down prefill instance {}".format(self.manager_args.pd_ratio, + "all instances are prefill instances, scale down prefill instance {}".format(self.manager_args.pd_ratio, cur_num_prefill, cur_num_decode, scale_down_instance_id)) if scale_down_instance_id: diff --git a/tests/unit_test/global_scheduler/test_global_scheduler.py b/tests/unit_test/global_scheduler/test_global_scheduler.py index 18894572..ec229c1a 100644 --- a/tests/unit_test/global_scheduler/test_global_scheduler.py +++ b/tests/unit_test/global_scheduler/test_global_scheduler.py @@ -24,7 +24,7 @@ def init_global_scheduler(): - global_scheduler_config = GlobalSchedulerConfig(0, 'load', 1, 'defrag_constrained', 3.0, + global_scheduler_config = GlobalSchedulerConfig(0, 'load', 1, 'defrag', 3.0, 'avg_load', 'remaining_steps', 10, 60, False, False) global_scheduler = GlobalScheduler(global_scheduler_config) return global_scheduler diff --git a/tests/unit_test/global_scheduler/test_manager.py b/tests/unit_test/global_scheduler/test_manager.py index 50f353bc..48019a08 100644 --- a/tests/unit_test/global_scheduler/test_manager.py +++ b/tests/unit_test/global_scheduler/test_manager.py @@ -106,13 +106,6 @@ def get_num_migrate_out(self): def get_num_migrate_in(self): return self.num_migrate_in -class MockBackendSim(BackendSimVLLM): - def _get_lantecy_mem(self, *args, **kwargs): - latency_mem = LatencyMemData({}, {}, {}) - latency_mem.prefill_model_params = (0,0) - latency_mem.decode_model_params = (0,0,0) - return latency_mem - def init_manager(): try: manager_args = ManagerArgs(enable_migration=True) @@ -224,15
+217,14 @@ def test_init_instances(ray_env, manager): assert num_instances == manager_args.initial_instances def test_init_instances_sim(ray_env, manager): - manager.profiling_result_file_path="//" # pylint: disable=import-outside-toplevel - import llumnix.backends.vllm.sim_llm_engine - llumnix.backends.vllm.sim_llm_engine.BackendSimVLLM = MockBackendSim - engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True) - _, instances = ray.get(manager.init_instances.remote(QueueType("rayqueue"), BackendType.SIM_VLLM, InstanceArgs(), engine_args)) - num_instances = len(instances) - manager_args = ManagerArgs() - assert num_instances == manager_args.initial_instances + try: + engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True) + _, _ = ray.get(manager.init_instances.remote(QueueType("rayqueue"), BackendType.SIM_VLLM, + InstanceArgs(profiling_result_file_path="/"), engine_args)) + assert False + except Exception as e: + assert isinstance(e, IsADirectoryError) def test_scale_up_and_down(ray_env, manager): initial_instances = 4 From 05fc34dba312777abb3bf61da2bde5a6aecb1011 Mon Sep 17 00:00:00 2001 From: s5u13b Date: Wed, 12 Feb 2025 02:03:35 +0000 Subject: [PATCH 47/59] Fix simulator test --- tests/unit_test/global_scheduler/test_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit_test/global_scheduler/test_manager.py b/tests/unit_test/global_scheduler/test_manager.py index 48019a08..bdbc97bc 100644 --- a/tests/unit_test/global_scheduler/test_manager.py +++ b/tests/unit_test/global_scheduler/test_manager.py @@ -218,11 +218,11 @@ def test_init_instances(ray_env, manager): def test_init_instances_sim(ray_env, manager): # pylint: disable=import-outside-toplevel + # cannot catch by pytest.raises try: engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True) _, _ = ray.get(manager.init_instances.remote(QueueType("rayqueue"), BackendType.SIM_VLLM, InstanceArgs(profiling_result_file_path="/"), engine_args)) - assert False except Exception as e: assert isinstance(e, IsADirectoryError) From 814bf96362eff7433a3d230452cc5e654e64979b Mon Sep 17 00:00:00 2001 From: s5u13b Date: Wed, 12 Feb 2025 02:37:53 +0000 Subject: [PATCH 48/59] Minors --- llumnix/global_scheduler/migration_filter.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/llumnix/global_scheduler/migration_filter.py b/llumnix/global_scheduler/migration_filter.py index 3ea3c637..e5ebee76 100644 --- a/llumnix/global_scheduler/migration_filter.py +++ b/llumnix/global_scheduler/migration_filter.py @@ -67,7 +67,7 @@ def filter_instances(self, instance_infos: List[InstanceInfo], if pair_migration_type == PairMigrationConstraints.NO_CONSTRAINTS: policy_filter = MigrationFilterPolicyFactory.get_policy("load") elif pair_migration_type in [PairMigrationConstraints.PREFILL_2_DECODING, PairMigrationConstraints.DECODING_2_DECODING]: - policy_filter = MigrationFilterPolicyFactory.get_policy('prefill_decode') + policy_filter = MigrationFilterPolicyFactory.get_policy('pdd') else: raise ValueError(f"Unsupported pair migration type: {pair_migration_type}") src_filter_conditions.append(policy_filter.filter_src_condition(self.filter_config, pair_migration_type)) @@ -79,7 +79,7 @@ def filter_instances(self, instance_infos: List[InstanceInfo], return filtered_src_instance_infos, filtered_dst_instance_infos -class LoadConstrainedFilter(MigrationFilterPolicy): +class LoadFilter(MigrationFilterPolicy): def filter_src_condition(self, filter_config: 
MigrationFilterConfig, pair_migration_type: PairMigrationConstraints) -> Callable[[InstanceInfo], bool]: return lambda instance_info: instance_info.num_killed_requests > 0 \ @@ -148,8 +148,8 @@ def filter_dst_condition(self, filter_config: MigrationFilterConfig, class MigrationFilterPolicyFactory: _POLICY_REGISTRY = { - 'load': LoadConstrainedFilter, - 'prefill_decode': PddFilter, + 'load': LoadFilter, + 'pdd': PddFilter, 'custom': CustomFilter, } From 10c49aee3423c9fb96fb14cd9416e67de326aafd Mon Sep 17 00:00:00 2001 From: s5u13b Date: Wed, 12 Feb 2025 02:46:03 +0000 Subject: [PATCH 49/59] Fix ip address --- tests/e2e_test/test_bench.py | 6 ++++-- tests/e2e_test/test_correctness.py | 4 +++- tests/e2e_test/test_migration.py | 12 +++++++++--- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/tests/e2e_test/test_bench.py b/tests/e2e_test/test_bench.py index 3491b3e1..72c83fd7 100644 --- a/tests/e2e_test/test_bench.py +++ b/tests/e2e_test/test_bench.py @@ -20,6 +20,8 @@ import torch import numpy as np +from llumnix.entrypoints.utils import get_ip_address + # pylint: disable=unused-import from tests.conftest import ray_env from .utils import (generate_launch_command, generate_bench_command, to_markdown_table, @@ -72,7 +74,7 @@ async def test_simple_benchmark(ray_env, shutdown_llumnix_service, model, launch else: num_prompts = 50 if not enable_pd_disagg else 50 - ip = "127.0.0.1" + ip = get_ip_address() base_port = 37037 ip_ports = [] if launch_mode == 'local': @@ -137,7 +139,7 @@ def run_bench_command(command): tasks = [] for i in range(device_count): bench_command = generate_bench_command( - ip_ports=f"127.0.0.1:{base_port + i}", + ip_ports=f"{ip}:{base_port + i}", model=model, num_prompts=num_prompts, dataset_type="sharegpt", diff --git a/tests/e2e_test/test_correctness.py b/tests/e2e_test/test_correctness.py index 64aeaf30..d5177294 100644 --- a/tests/e2e_test/test_correctness.py +++ b/tests/e2e_test/test_correctness.py @@ -18,6 +18,8 @@ import ray import torch +from llumnix.entrypoints.utils import get_ip_address + from vllm import LLM, SamplingParams # pylint: disable=unused-import @@ -85,7 +87,7 @@ async def test_correctness(ray_env, shutdown_llumnix_service, model, launch_mode await asyncio.sleep(5) # generate llumnix outputs - ip = "127.0.0.1" + ip = get_ip_address() base_port = 37037 launch_commands = [] diff --git a/tests/e2e_test/test_migration.py b/tests/e2e_test/test_migration.py index 1a4121b7..057b85bd 100644 --- a/tests/e2e_test/test_migration.py +++ b/tests/e2e_test/test_migration.py @@ -20,6 +20,8 @@ import torch import ray +from llumnix.entrypoints.utils import get_ip_address + # pylint: disable=unused-import from tests.conftest import ray_env from .utils import (generate_launch_command, generate_bench_command, to_markdown_table, @@ -44,6 +46,8 @@ def parse_instance_log_file(log_files): total_kv_cache_size = size_match.group(0).split(": ")[1].strip() speed = float(speed_match.group(1)) speed_dict[total_kv_cache_size].append(speed) + + print(speed_dict) average_speed = {} for transfer_size, speeds in speed_dict.items(): @@ -57,6 +61,8 @@ def parse_instance_log_file(log_files): average_speed[transfer_size] = sum(trimmed_speeds) / len(trimmed_speeds) assert len(average_speed) > 0, "Migration should have occurred, but it was not detected. 
" + + print(average_speed) return average_speed @@ -98,7 +104,7 @@ async def test_migration_benchmark(ray_env, shutdown_llumnix_service, model, mig pytest.skip("When the migrated request status is waiting, only test the rayrpc migration backend.") request_migration_policy = 'SR' if migrated_request_status == 'running' else 'FCW' - ip = "127.0.0.1" + ip = get_ip_address() base_port = 37037 ip_ports = [] instance_output_logs = [] @@ -130,9 +136,9 @@ def run_bench_command(command): tasks = [] for i in range(device_count // 2): bench_command = generate_bench_command( - ip_ports=f"127.0.0.1:{base_port + i}", + ip_ports=f"{ip}:{base_port + i}", model=model, - num_prompts=500, + num_prompts=100, dataset_type="sharegpt", dataset_path="/mnt/dataset/sharegpt_gpt4/sharegpt_gpt4.jsonl", qps=10, From 895d68b8f1327e5b7c2e66c1d89193f4bf29943d Mon Sep 17 00:00:00 2001 From: s5u13b Date: Wed, 12 Feb 2025 04:36:38 +0000 Subject: [PATCH 50/59] Refine instance ready & migration size sort --- llumnix/entrypoints/setup.py | 13 ++++++++++++- llumnix/manager.py | 16 +++------------- tests/e2e_test/test_correctness.py | 2 +- tests/e2e_test/test_migration.py | 12 ++++-------- 4 files changed, 20 insertions(+), 23 deletions(-) diff --git a/llumnix/entrypoints/setup.py b/llumnix/entrypoints/setup.py index cc3126c5..d0e82a43 100644 --- a/llumnix/entrypoints/setup.py +++ b/llumnix/entrypoints/setup.py @@ -125,11 +125,22 @@ def init_llumnix_components(entrypoints_args: EntrypointsArgs, manager.init_instances.remote, 'init_instances', request_output_queue_type, backend_type, instance_args, engine_args) + available_instance_ids = [] + available_instances = [] + for instance_id, instance in zip(instance_ids, instances): + try: + ray.get(instance.is_ready.remote()) + available_instance_ids.append(instance_id) + available_instances.append(instance) + except: + logger.info("Instance {} is dead.".format(instance_id)) + retry_manager_method_sync(manager.scale_down.remote, 'scale_down', instance_id) + ip = get_ip_address() request_output_queue_port: str = entrypoints_args.request_output_queue_port request_output_queue = init_request_output_queue_server(ip, request_output_queue_port, request_output_queue_type) - return manager, instance_ids, instances, request_output_queue + return manager, available_instance_ids, available_instances, request_output_queue def setup_entrypoints_context(entrypoints_args, manager, instance_ids, instances, request_output_queue) -> EntrypointsContext: instances_dict: Dict[str, Llumlet] = {} diff --git a/llumnix/manager.py b/llumnix/manager.py index 0f6ca99a..65a42d91 100644 --- a/llumnix/manager.py +++ b/llumnix/manager.py @@ -547,20 +547,10 @@ def init_instances(self, instance_ids.append(instance_id) instances.append(instance) - available_instance_ids = [] - available_instances = [] - for instance_id, instance in zip(instance_ids, instances): - try: - ray.get(instance.is_ready.remote()) - available_instance_ids.append(instance_id) - available_instances.append(instance) - except ray.exceptions.RayActorError: - logger.error("Instance {} is dead".format(instance_id)) - self.launcher.clear_instance_ray_resources(instance_id) - - self.scale_up(available_instance_ids, available_instances, [instance_args]*len(available_instance_ids)) + # Because init_instances is called by multiple nodes simultaneously, we dot not wait instances ready here. 
+ self.scale_up(instance_ids, instances, [instance_args]*len(instance_ids)) - return available_instance_ids, available_instances + return instance_ids, instances def _inner_check_pd_deployment(self) -> str: prefill_instance_ids = self.global_scheduler.dispatch_scheduler.available_dispatch_instance_set diff --git a/tests/e2e_test/test_correctness.py b/tests/e2e_test/test_correctness.py index d5177294..162304b8 100644 --- a/tests/e2e_test/test_correctness.py +++ b/tests/e2e_test/test_correctness.py @@ -122,7 +122,7 @@ async def test_correctness(ray_env, shutdown_llumnix_service, model, launch_mode subprocess.run(launch_command, shell=True, check=True) await asyncio.sleep(3) - wait_for_llumnix_service_ready(ip_ports=[f"{ip}:{base_port}"], timeout=120) + wait_for_llumnix_service_ready(ip_ports=[f"{ip}:{base_port}"]) llumnix_output = {} for prompt in prompts: diff --git a/tests/e2e_test/test_migration.py b/tests/e2e_test/test_migration.py index 057b85bd..aa4a9ff6 100644 --- a/tests/e2e_test/test_migration.py +++ b/tests/e2e_test/test_migration.py @@ -46,8 +46,6 @@ def parse_instance_log_file(log_files): total_kv_cache_size = size_match.group(0).split(": ")[1].strip() speed = float(speed_match.group(1)) speed_dict[total_kv_cache_size].append(speed) - - print(speed_dict) average_speed = {} for transfer_size, speeds in speed_dict.items(): @@ -61,8 +59,6 @@ def parse_instance_log_file(log_files): average_speed[transfer_size] = sum(trimmed_speeds) / len(trimmed_speeds) assert len(average_speed) > 0, "Migration should have occurred, but it was not detected. " - - print(average_speed) return average_speed @@ -97,8 +93,8 @@ def get_instance_num_blocks(): @pytest.mark.asyncio @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="at least 2 gpus required for migration bench") @pytest.mark.parametrize("model", ['/mnt/model/Qwen-7B']) -@pytest.mark.parametrize("migration_backend", ['rayrpc', 'gloo', 'nccl']) -@pytest.mark.parametrize("migrated_request_status", ['running', 'waiting']) +@pytest.mark.parametrize("migration_backend", ['rayrpc']) +@pytest.mark.parametrize("migrated_request_status", ['running']) async def test_migration_benchmark(ray_env, shutdown_llumnix_service, model, migration_backend, migrated_request_status): if migrated_request_status == 'waiting' and migration_backend != 'rayrpc': pytest.skip("When the migrated request status is waiting, only test the rayrpc migration backend.") @@ -138,7 +134,7 @@ def run_bench_command(command): bench_command = generate_bench_command( ip_ports=f"{ip}:{base_port + i}", model=model, - num_prompts=100, + num_prompts=500, dataset_type="sharegpt", dataset_path="/mnt/dataset/sharegpt_gpt4/sharegpt_gpt4.jsonl", qps=10, @@ -168,7 +164,7 @@ def run_bench_command(command): if migrated_request_status == 'running': average_speed = parse_instance_log_file(instance_output_logs) - sorted_keys = sorted(average_speed.keys(), key=lambda x: float(x.split()[0])) + sorted_keys = sorted(average_speed.keys(), key=lambda x: float(x.split()[0])*1024 if 'GB' in x else float(x.split()[0])) data = [ ['migration_size'] + sorted_keys, [f'{migration_backend}_speed(GB/s)'] + [f"{average_speed[key]:.2f}" for key in sorted_keys] From ed8588b29272aec2642667a9bc4d15e1b7b36ffd Mon Sep 17 00:00:00 2001 From: s5u13b Date: Wed, 12 Feb 2025 04:38:54 +0000 Subject: [PATCH 51/59] Fix lint --- llumnix/entrypoints/setup.py | 7 +++++-- tests/e2e_test/test_correctness.py | 4 ++-- tests/unit_test/global_scheduler/test_manager.py | 3 +-- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git 
a/llumnix/entrypoints/setup.py b/llumnix/entrypoints/setup.py index d0e82a43..d68bf363 100644 --- a/llumnix/entrypoints/setup.py +++ b/llumnix/entrypoints/setup.py @@ -132,8 +132,11 @@ def init_llumnix_components(entrypoints_args: EntrypointsArgs, ray.get(instance.is_ready.remote()) available_instance_ids.append(instance_id) available_instances.append(instance) - except: - logger.info("Instance {} is dead.".format(instance_id)) + # pylint: disable=broad-except + except Exception as e: + logger.error("Instance {} is dead.".format(instance_id)) + logger.error("Unexpected exception occurs: {}".format(e)) + logger.error("Exception traceback: {}".format(traceback.format_exc())) retry_manager_method_sync(manager.scale_down.remote, 'scale_down', instance_id) ip = get_ip_address() diff --git a/tests/e2e_test/test_correctness.py b/tests/e2e_test/test_correctness.py index 162304b8..8af96cf5 100644 --- a/tests/e2e_test/test_correctness.py +++ b/tests/e2e_test/test_correctness.py @@ -18,10 +18,10 @@ import ray import torch -from llumnix.entrypoints.utils import get_ip_address - from vllm import LLM, SamplingParams +from llumnix.entrypoints.utils import get_ip_address + # pylint: disable=unused-import from tests.conftest import ray_env from .utils import (generate_launch_command, generate_serve_command, wait_for_llumnix_service_ready, diff --git a/tests/unit_test/global_scheduler/test_manager.py b/tests/unit_test/global_scheduler/test_manager.py index bdbc97bc..3611fe04 100644 --- a/tests/unit_test/global_scheduler/test_manager.py +++ b/tests/unit_test/global_scheduler/test_manager.py @@ -29,9 +29,7 @@ from llumnix.server_info import ServerInfo from llumnix.queue.queue_type import QueueType from llumnix.global_scheduler.scaling_scheduler import InstanceType -from llumnix.backends.vllm.sim_llm_engine import BackendSimVLLM from llumnix.backends.backend_interface import BackendType -from llumnix.backends.profiling import LatencyMemData from llumnix.entrypoints.utils import LaunchMode from llumnix.utils import (get_placement_group_name, get_server_name, get_instance_name, remove_placement_group, INSTANCE_NAME_PREFIX, kill_server, @@ -223,6 +221,7 @@ def test_init_instances_sim(ray_env, manager): engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True) _, _ = ray.get(manager.init_instances.remote(QueueType("rayqueue"), BackendType.SIM_VLLM, InstanceArgs(profiling_result_file_path="/"), engine_args)) + # pylint: disable=broad-except except Exception as e: assert isinstance(e, IsADirectoryError) From 73f0f15325c184f299d20a0166c0961de4c37281 Mon Sep 17 00:00:00 2001 From: s5u13b Date: Wed, 12 Feb 2025 06:05:16 +0000 Subject: [PATCH 52/59] Refine timestamps --- llumnix/backends/vllm/llm_engine.py | 10 +++++----- llumnix/entrypoints/bladellm/client.py | 3 ++- llumnix/entrypoints/vllm/client.py | 2 +- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/llumnix/backends/vllm/llm_engine.py b/llumnix/backends/vllm/llm_engine.py index 0263fb4c..ad01d54c 100644 --- a/llumnix/backends/vllm/llm_engine.py +++ b/llumnix/backends/vllm/llm_engine.py @@ -176,17 +176,17 @@ def _process_model_outputs(self, if ctx.request_outputs: request_outputs, server_infos = zip(*ctx.request_outputs) + for request_output, server_info in zip(request_outputs, server_infos): if hasattr(server_info, 'request_timestamps'): request_output.request_timestamps = server_info.request_timestamps - request_output.request_timestamps.engine_process_model_outputs_timestamp_end = time.time() + set_timestamp(request_outputs, 
'engine_process_model_outputs_timestamp_end', time.time()) return def _process_request_outputs( self, outputs: List[Tuple[RequestOutput, ServerInfo]], - step_begin_time: float ) -> Tuple[List[RequestOutput], List[ServerInfo]]: request_outputs = [] server_infos = [] @@ -195,7 +195,7 @@ def _process_request_outputs( request_outputs = list(request_outputs) server_infos = list(server_infos) - set_timestamp(request_outputs, 'engine_step_timestamp_begin', step_begin_time) + set_timestamp(request_outputs, 'engine_step_timestamp_begin', self.step_begin_time) set_timestamp(request_outputs, 'engine_step_timestamp_end', time.time()) for request_output in request_outputs: @@ -232,10 +232,10 @@ def _process_request_outputs( return request_outputs, server_infos async def step_async(self) -> Tuple[List[RequestOutput], List[ServerInfo]]: - step_begin_time = time.time() + self.step_begin_time = time.time() # pylint: disable=too-many-function-args outputs = await super().step_async(0) - return self._process_request_outputs(outputs, step_begin_time) + return self._process_request_outputs(outputs) def update_instance_info(self, instance_info: InstanceInfo) -> None: # These fields are updated after step. diff --git a/llumnix/entrypoints/bladellm/client.py b/llumnix/entrypoints/bladellm/client.py index fe95cc21..a06cf027 100644 --- a/llumnix/entrypoints/bladellm/client.py +++ b/llumnix/entrypoints/bladellm/client.py @@ -31,6 +31,7 @@ from llumnix.entrypoints.utils import EntrypointsContext from llumnix.logging.logger import init_logger from llumnix.constants import WAIT_MANAGER_INTERVAL +from llumnix.metrics.timestamps import set_timestamp logger = init_logger(__name__) @@ -87,7 +88,7 @@ async def _manager_generate(self, request, request_id: str) -> LLMResponse: if self.llumnix_context.log_request_timestamps: # Hack request timestamps in server_info for latency breakdown. server_info_copy.request_timestamps = RequestTimestamps() - server_info_copy.request_timestamps.api_server_generate_timestamp = time.time() + set_timestamp(server_info_copy, "api_server_generate_timestamp", time.time()) # await to catch exception await self.llumnix_context.manager.generate.remote(str(request_id), server_info_copy, server_request=request) self.llumnix_context.manager_available = True diff --git a/llumnix/entrypoints/vllm/client.py b/llumnix/entrypoints/vllm/client.py index 0b10fa44..e6ec713c 100644 --- a/llumnix/entrypoints/vllm/client.py +++ b/llumnix/entrypoints/vllm/client.py @@ -73,7 +73,7 @@ async def _generate_by_manager(self, if self.log_request_timestamps: # Hack request timestamps in server_info for latency breakdown. 
server_info.request_timestamps = RequestTimestamps() - server_info.request_timestamps.api_server_generate_timestamp = time.time() + set_timestamp(server_info, "api_server_generate_timestamp", time.time()) await self.manager.generate.remote(request_id, server_info, prompt, sampling_params, *args, **kwargs) async def _generate_by_instance(self, From 59f2b085caa5c3c36f5ac121280966b40a5d6599 Mon Sep 17 00:00:00 2001 From: s5u13b Date: Wed, 12 Feb 2025 07:43:33 +0000 Subject: [PATCH 53/59] Resort manager and launcher functions & Fix test_manager --- llumnix/arg_utils.py | 4 + llumnix/instance_info.py | 7 + llumnix/internal_config.py | 52 +-- llumnix/launcher.py | 191 ++++---- llumnix/manager.py | 414 ++++++++++-------- tests/e2e_test/test_migration.py | 4 +- .../global_scheduler/test_manager.py | 60 ++- 7 files changed, 372 insertions(+), 360 deletions(-) diff --git a/llumnix/arg_utils.py b/llumnix/arg_utils.py index 5fcd45ab..f819ac55 100644 --- a/llumnix/arg_utils.py +++ b/llumnix/arg_utils.py @@ -44,6 +44,7 @@ def add_argument(self, *args, **kwargs): kwargs['default'] = None super().add_argument(*args, **kwargs) + @dataclass class EntrypointsArgs: host: str = None @@ -112,6 +113,7 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: help="path to config file of arguments") return parser + @dataclass class ManagerArgs: initial_instances: int = None @@ -298,11 +300,13 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: help='the prefill decode ratio used in global launch mode e.g. "1:1"') return parser + @dataclass class LaunchArgs: launch_mode: LaunchMode = None backend_type: BackendType = None + @dataclass class InstanceArgs: instance_type: str = None diff --git a/llumnix/instance_info.py b/llumnix/instance_info.py index c79049c0..0070a8a3 100644 --- a/llumnix/instance_info.py +++ b/llumnix/instance_info.py @@ -23,11 +23,13 @@ logger = init_logger(__name__) + class InstanceType(str, Enum): NO_CONSTRAINTS = "no_constraints" PREFILL = "prefill" DECODE = "decode" + @dataclass class InstanceInfo: instance_id: str = "" @@ -71,6 +73,7 @@ def __post_init__(self) -> None: self.num_available_gpu_blocks = self.num_free_gpu_blocks - self.num_watermark_blocks self.num_available_gpu_blocks_waiting = self.num_available_gpu_blocks - self.num_blocks_all_waiting_requests + class InstanceLoadCalculator: def __init__(self, dispatch_load_metric: str, migration_load_metric: str, enable_defrag: bool) -> None: self.dispatch_load_calculator = DispatchLoadComputation(migration_load_metric) @@ -84,6 +87,7 @@ def compute_instance_load(self, instance_info: InstanceInfo): instance_info.migration_load_metric_after_migrate_in = self.migration_load_calculator.\ compute_instance_load_after_migrate(instance_info, is_migrate_in=True) + class LoadComputationStrategy(ABC): def __init__(self, load_metric: str, enable_defrag: bool = False) -> None: self.load_metric = load_metric @@ -93,6 +97,7 @@ def __init__(self, load_metric: str, enable_defrag: bool = False) -> None: def compute_instance_load(self, instance_info: InstanceInfo) -> float: pass + class DispatchLoadComputation(LoadComputationStrategy): def compute_instance_load(self, instance_info: InstanceInfo) -> float: instance_load = -np.inf @@ -107,6 +112,7 @@ def compute_instance_load(self, instance_info: InstanceInfo) -> float: instance_load = (num_available_gpu_blocks / num_requests)*(-1) return instance_load + class MigrationLoadComputation(LoadComputationStrategy): def compute_instance_load_after_migrate(self,
instance_info: InstanceInfo, is_migrate_in: bool) -> float: instance_info_after_migrate = copy.deepcopy(instance_info) @@ -141,6 +147,7 @@ def compute_instance_load(self, instance_info: InstanceInfo) -> float: instance_load = (num_available_gpu_blocks / num_requests) * (-1) return instance_load + # TODO(KuilongCui): currently scaling and dispatch use the same load calculator, leave # it in the future to refine class ScalingLoadComputation(LoadComputationStrategy): diff --git a/llumnix/internal_config.py b/llumnix/internal_config.py index 03a0a5be..4b02f8db 100644 --- a/llumnix/internal_config.py +++ b/llumnix/internal_config.py @@ -11,32 +11,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -class MigrationConfig: - def __init__( - self, - request_migration_policy: str, - migration_backend: str, - migration_buffer_blocks: int, - migration_num_layers: int, - migration_last_stage_max_blocks: int, - migration_max_stages: int, - migration_backend_init_timeout: float, - migration_backend_transfer_type: str = "", - grpc_migration_backend_server_address: str = "", - kvtransfer_migration_backend_naming_url: str = "", - ) -> None: - self.request_migration_policy = request_migration_policy - self.migration_backend = migration_backend - self.migration_backend_transfer_type = migration_backend_transfer_type - self.migration_num_layers = migration_num_layers - self.migration_buffer_blocks = migration_buffer_blocks - self.migration_last_stage_max_blocks = migration_last_stage_max_blocks - self.migration_max_stages = migration_max_stages - self.migration_backend_init_timeout = migration_backend_init_timeout - self.grpc_migration_backend_server_address = grpc_migration_backend_server_address - self.kvtransfer_migration_backend_naming_url = kvtransfer_migration_backend_naming_url - - class GlobalSchedulerConfig: def __init__( self, @@ -65,3 +39,29 @@ def __init__( self.enable_pd_disagg = enable_pd_disagg self.is_group_kind_migration_backend = is_group_kind_migration_backend + + +class MigrationConfig: + def __init__( + self, + request_migration_policy: str, + migration_backend: str, + migration_buffer_blocks: int, + migration_num_layers: int, + migration_last_stage_max_blocks: int, + migration_max_stages: int, + migration_backend_init_timeout: float, + migration_backend_transfer_type: str = "", + grpc_migration_backend_server_address: str = "", + kvtransfer_migration_backend_naming_url: str = "", + ) -> None: + self.request_migration_policy = request_migration_policy + self.migration_backend = migration_backend + self.migration_backend_transfer_type = migration_backend_transfer_type + self.migration_num_layers = migration_num_layers + self.migration_buffer_blocks = migration_buffer_blocks + self.migration_last_stage_max_blocks = migration_last_stage_max_blocks + self.migration_max_stages = migration_max_stages + self.migration_backend_init_timeout = migration_backend_init_timeout + self.grpc_migration_backend_server_address = grpc_migration_backend_server_address + self.kvtransfer_migration_backend_naming_url = kvtransfer_migration_backend_naming_url diff --git a/llumnix/launcher.py b/llumnix/launcher.py index 6d29d3a9..73c741dd 100644 --- a/llumnix/launcher.py +++ b/llumnix/launcher.py @@ -14,10 +14,9 @@ import asyncio import copy import traceback -from typing import Callable, Dict, List, Tuple +from typing import Callable, List, Tuple import ray -from ray.util.state import list_placement_groups, list_actors from ray.util.placement_group import 
PlacementGroup from llumnix.logging.logger import init_logger @@ -29,10 +28,10 @@ from llumnix.arg_utils import EntrypointsArgs, InstanceArgs from llumnix.entrypoints.vllm.api_server_actor import APIServerActor from llumnix.backends.utils import get_engine_world_size -from llumnix.utils import (remove_placement_group, get_manager_name, INSTANCE_NAME_PREFIX, get_instance_name, - SERVER_NAME_PREFIX, kill_server, kill_instance, get_actor_data_from_ray_internal_kv, - initialize_placement_group, get_server_name, put_actor_data_to_ray_internal_kv, - get_placement_group_name) +from llumnix.utils import (initialize_placement_group, remove_placement_group, + get_manager_name, get_server_name, + kill_server, kill_instance, + get_actor_data_from_ray_internal_kv, put_actor_data_to_ray_internal_kv) logger = init_logger(__name__) @@ -59,8 +58,8 @@ def __init__(self, if value is not None: self.port_offset = int(value) - self.inflight_num_prefill = 0 - self.inflight_num_decode = 0 + self.inflight_num_prefill_instance = 0 + self.inflight_num_decode_instance = 0 def init_placement_group(self, placement_group_name: str, @@ -80,102 +79,6 @@ def init_placement_group(self, return placement_group - def get_instance_deployment_states(self, instance_id: str): - pg_state = list_placement_groups(filters=[("name", "=", get_placement_group_name(instance_id))]) - pg_created = len(pg_state) == 1 and pg_state[0]["state"] == "CREATED" - server_state = list_actors(filters=[("name", "=", get_server_name(instance_id))]) - server_alive = len(server_state) == 1 and server_state[0]["state"] == "ALIVE" - instance_state = list_actors(filters=[("name", "=", get_instance_name(instance_id))]) - instance_alive = len(instance_state) == 1 and instance_state[0]["state"] == "ALIVE" - - return pg_created, server_alive, instance_alive - - def get_cluster_deployment(self) -> Tuple[Dict[str, PlacementGroup], Dict[str, APIServerActor], Dict[str, Llumlet]]: - curr_pgs: Dict[str, PlacementGroup] = {} - curr_servers: Dict[str, PlacementGroup] = {} - curr_instances: Dict[str, Llumlet] = {} - - created_pg_states = list_placement_groups(filters=[("state", "=", "CREATED")]) - for created_pg_state in created_pg_states: - instance_id = created_pg_state["name"].split("_")[-1] - curr_pgs[instance_id] = ray.util.get_placement_group(created_pg_state["name"]) - - alive_actor_states = list_actors(filters=[("state", "=", "ALIVE")]) - for alive_actor_state in alive_actor_states: - if alive_actor_state["name"].startswith(SERVER_NAME_PREFIX): - instance_id = alive_actor_state["name"].split("_")[-1] - curr_servers[instance_id] = ray.get_actor(alive_actor_state["name"], namespace="llumnix") - elif alive_actor_state["name"].startswith(INSTANCE_NAME_PREFIX): - instance_id = alive_actor_state["name"].split("_")[-1] - curr_instances[instance_id] = ray.get_actor(alive_actor_state["name"], namespace="llumnix") - - return curr_pgs, curr_servers, curr_instances - - def clear_instance_ray_resources(self, instance_id: str): - if not remove_placement_group(instance_id): - logger.debug("Failed to remove placement group {}.".format(instance_id)) - if not kill_server(instance_id): - logger.debug("Failed to kill server {}.".format(instance_id)) - if not kill_instance(instance_id): - logger.debug("Failed to kill instance {}.".format(instance_id)) - - def _get_next_instance_type(self, - cur_num_prefill: int, - cur_num_decode: int, - pd_ratio: List[int]) -> str: - instance_type = InstanceType.NO_CONSTRAINTS - - if self.enable_pd_disagg: - # Note: There are no instances 
simultaneously in inflight_num_prefill and cur_num_prefill as - # inflight_num will decrease before scaling up the instances. The same applies to num_decode. - total_num_prefill = self.inflight_num_prefill + cur_num_prefill - total_num_decode = self.inflight_num_decode + cur_num_decode - - if total_num_prefill == 0: - instance_type = InstanceType.PREFILL - elif total_num_decode == 0: - instance_type = InstanceType.DECODE - else: - # compute distance if launch prefill or decode - normal_distance = pd_ratio[0] - pd_ratio[1] - - base_num_ratio = min(total_num_prefill//pd_ratio[0], total_num_decode//pd_ratio[1]) - total_num_prefill = total_num_prefill - base_num_ratio * pd_ratio[0] - total_num_decode = total_num_decode - base_num_ratio * pd_ratio[1] - - if total_num_prefill + total_num_decode == 0: - instance_type = InstanceType.PREFILL - else: - distance_if_prefill = total_num_prefill + 1 - total_num_decode - distance_if_decode = total_num_prefill - (total_num_decode + 1) - gap_to_normal_if_prefill = abs(distance_if_prefill - normal_distance) - gap_to_normal_if_decode = abs(distance_if_decode - normal_distance) - instance_type = InstanceType.PREFILL if gap_to_normal_if_prefill <= gap_to_normal_if_decode \ - else InstanceType.DECODE - - return instance_type - - def _get_next_instance_args(self, instance_args: InstanceArgs) -> InstanceArgs: - assert not self.enablde_engine_pd_disagg, \ - "Currently not support engine based pd-disaggregation in global launch mode." - - next_instance_args: InstanceArgs = copy.deepcopy(instance_args) - cur_num_prefill = len(self.global_scheduler.dispatch_scheduler.available_dispatch_instance_set) - cur_num_decode = len(self.global_scheduler.instance_id_set - - self.global_scheduler.dispatch_scheduler.available_dispatch_instance_set) - next_instance_args.instance_type = self._get_next_instance_type(cur_num_prefill, cur_num_decode, self.pd_ratio) - return next_instance_args - - def _get_next_entrypoints_args(self, entrypoints_args: EntrypointsArgs) -> EntrypointsArgs: - next_entrypoints_args = copy.deepcopy(entrypoints_args) - if self.enable_port_increment: - next_entrypoints_args.port += self.port_offset - next_entrypoints_args.request_output_queue_port += self.port_offset - self.port_offset += 1 - if self.enable_port_offset_store: - put_actor_data_to_ray_internal_kv("manager", "port_offset", self.port_offset) - return next_entrypoints_args - def init_server_and_instance(self, instance_id: str, entrypoints_args: EntrypointsArgs, @@ -190,8 +93,8 @@ async def done_scale_up(instance_args: InstanceArgs, entrypoint_args: Entrypoint manager = ray.get_actor(get_manager_name(), namespace="llumnix") await instance.is_ready.remote() await server.run.remote(manager, instance_id, instance) - self.inflight_num_prefill -= 1 if instance_args.instance_type == InstanceType.PREFILL else 0 - self.inflight_num_decode -= 1 if instance_args.instance_type == InstanceType.DECODE else 0 + self.inflight_num_prefill_instance -= 1 if instance_args.instance_type == InstanceType.PREFILL else 0 + self.inflight_num_decode_instance -= 1 if instance_args.instance_type == InstanceType.DECODE else 0 if instance_ready_cb: # manager.scale_up will be called here after the instance is ready instance_ready_cb(instance_id, instance, instance_args) @@ -203,8 +106,8 @@ async def done_scale_up(instance_args: InstanceArgs, entrypoint_args: Entrypoint entrypoint_args.port, entrypoint_args.request_output_queue_port)) # pylint: disable=broad-except except Exception as e: - self.inflight_num_prefill -= 1 if 
instance_args.instance_type == InstanceType.PREFILL else 0 - self.inflight_num_decode -= 1 if instance_args.instance_type == InstanceType.DECODE else 0 + self.inflight_num_prefill_instance -= 1 if instance_args.instance_type == InstanceType.PREFILL else 0 + self.inflight_num_decode_instance -= 1 if instance_args.instance_type == InstanceType.DECODE else 0 logger.error("Unexpected exception occurs: {}".format(e)) logger.error("Exception traceback: {}".format(traceback.format_exc())) self.clear_instance_ray_resources(instance_id) @@ -216,8 +119,8 @@ async def done_scale_up(instance_args: InstanceArgs, entrypoint_args: Entrypoint next_entrypoints_args = self._get_next_entrypoints_args(entrypoints_args) server = self.init_server(get_server_name(instance_id), placement_group, next_entrypoints_args) - self.inflight_num_prefill += 1 if next_instance_args.instance_type == InstanceType.PREFILL else 0 - self.inflight_num_decode += 1 if next_instance_args.instance_type == InstanceType.DECODE else 0 + self.inflight_num_prefill_instance += 1 if next_instance_args.instance_type == InstanceType.PREFILL else 0 + self.inflight_num_decode_instance += 1 if next_instance_args.instance_type == InstanceType.DECODE else 0 asyncio.create_task(done_scale_up(next_instance_args, next_entrypoints_args)) def init_server(self, @@ -225,6 +128,7 @@ def init_server(self, placement_group: PlacementGroup, entrypoints_args: EntrypointsArgs) -> APIServerActor: fastapi_server = APIServerActor.from_args(server_name, placement_group, entrypoints_args) + return fastapi_server def init_instance(self, @@ -244,3 +148,70 @@ def init_instance(self, engine_args) return instance + + def clear_instance_ray_resources(self, instance_id: str): + if not remove_placement_group(instance_id): + logger.debug("Failed to remove placement group {}.".format(instance_id)) + if not kill_server(instance_id): + logger.debug("Failed to kill server {}.".format(instance_id)) + if not kill_instance(instance_id): + logger.debug("Failed to kill instance {}.".format(instance_id)) + + def _get_next_instance_args(self, instance_args: InstanceArgs) -> InstanceArgs: + assert not self.enablde_engine_pd_disagg, \ + "Currently not support engine based pd-disaggregation in global launch mode." + + next_instance_args: InstanceArgs = copy.deepcopy(instance_args) + cur_num_prefill = len(self.global_scheduler.dispatch_scheduler.available_dispatch_instance_set) + cur_num_decode = len(self.global_scheduler.instance_id_set - + self.global_scheduler.dispatch_scheduler.available_dispatch_instance_set) + next_instance_args.instance_type = self._get_next_instance_type(cur_num_prefill, cur_num_decode, self.pd_ratio) + + return next_instance_args + + def _get_next_entrypoints_args(self, entrypoints_args: EntrypointsArgs) -> EntrypointsArgs: + next_entrypoints_args = copy.deepcopy(entrypoints_args) + if self.enable_port_increment: + next_entrypoints_args.port += self.port_offset + next_entrypoints_args.request_output_queue_port += self.port_offset + self.port_offset += 1 + if self.enable_port_offset_store: + put_actor_data_to_ray_internal_kv("manager", "port_offset", self.port_offset) + + return next_entrypoints_args + + def _get_next_instance_type(self, + cur_num_prefill: int, + cur_num_decode: int, + pd_ratio: List[int]) -> str: + instance_type = InstanceType.NO_CONSTRAINTS + + if self.enable_pd_disagg: + # Note: There are no instances simultaneously in inflight_num_prefill and cur_num_prefill as + # inflight_num will decrease before scaling up the instances. 
The same applies to num_decode. + total_num_prefill = self.inflight_num_prefill_instance + cur_num_prefill + total_num_decode = self.inflight_num_decode_instance + cur_num_decode + + if total_num_prefill == 0: + instance_type = InstanceType.PREFILL + elif total_num_decode == 0: + instance_type = InstanceType.DECODE + else: + # compute distance if launch prefill or decode + normal_distance = pd_ratio[0] - pd_ratio[1] + + base_num_ratio = min(total_num_prefill//pd_ratio[0], total_num_decode//pd_ratio[1]) + total_num_prefill = total_num_prefill - base_num_ratio * pd_ratio[0] + total_num_decode = total_num_decode - base_num_ratio * pd_ratio[1] + + if total_num_prefill + total_num_decode == 0: + instance_type = InstanceType.PREFILL + else: + distance_if_prefill = total_num_prefill + 1 - total_num_decode + distance_if_decode = total_num_prefill - (total_num_decode + 1) + gap_to_normal_if_prefill = abs(distance_if_prefill - normal_distance) + gap_to_normal_if_decode = abs(distance_if_decode - normal_distance) + instance_type = InstanceType.PREFILL if gap_to_normal_if_prefill <= gap_to_normal_if_decode \ + else InstanceType.DECODE + + return instance_type diff --git a/llumnix/manager.py b/llumnix/manager.py index 65a42d91..4abab584 100644 --- a/llumnix/manager.py +++ b/llumnix/manager.py @@ -24,6 +24,7 @@ import ray import ray.actor from ray.util.state import list_placement_groups, list_actors +from ray.util.placement_group import PlacementGroup from llumnix.llumlet.llumlet import Llumlet from llumnix.logging.logger import init_logger @@ -34,9 +35,9 @@ from llumnix.arg_utils import ManagerArgs, EntrypointsArgs, InstanceArgs, LaunchArgs from llumnix.server_info import ServerInfo from llumnix.backends.backend_interface import BackendType -from llumnix.utils import (random_uuid, clear_gloo_backend_state, get_instance_name, - get_manager_name, INSTANCE_NAME_PREFIX, get_placement_group_name, - run_async_func_sync,) +from llumnix.utils import (random_uuid, clear_gloo_backend_state, get_server_name, + get_instance_name, get_manager_name, get_placement_group_name, + INSTANCE_NAME_PREFIX, SERVER_NAME_PREFIX, run_async_func_sync) from llumnix.entrypoints.utils import LaunchMode from llumnix.queue.queue_type import QueueType from llumnix.constants import (CLEAR_REQUEST_INSTANCE_INTERVAL, NO_INSTANCE_RETRY_GENERATE_INTERVAL, @@ -45,10 +46,12 @@ WATCH_DEPLOYMENT_INTERVAL, WATCH_DEPLOYMENT_INTERVAL_PENDING_INSTANCE) from llumnix.launcher import Launcher from llumnix.metrics.timestamps import set_timestamp +from llumnix.entrypoints.vllm.api_server_actor import APIServerActor logger = init_logger(__name__) # TODO(s5u13b): Handle exception of ray operations. +# TODO(s5u13b): Refactor manager to divide functions into different classes. 
 class Manager:
@@ -195,6 +198,55 @@ def abort_done_callback(instance_id: str, request_ids: List[str], fut):
             tasks.append(task)
         await asyncio.gather(*tasks, return_exceptions=True)
 
+    @classmethod
+    def from_args(cls,
+                  entrypoints_args: EntrypointsArgs,
+                  manager_args: ManagerArgs,
+                  instance_args: InstanceArgs,
+                  engine_args,
+                  launch_args: LaunchArgs,
+                  ) -> "Manager":
+        manager_class = ray.remote(num_cpus=1,
+                                   max_restarts=-1,
+                                   name=get_manager_name(),
+                                   namespace="llumnix",
+                                   lifetime="detached")(cls)
+        manager = manager_class.remote(
+            entrypoints_args,
+            manager_args,
+            instance_args,
+            engine_args,
+            launch_args,
+            os.getcwd())
+        return manager
+
+    def init_instances(self,
+                       request_output_queue_type: QueueType,
+                       backend_type: BackendType,
+                       instance_args: InstanceArgs,
+                       engine_args
+                       ) -> Tuple[List[str], List[Llumlet]]:
+        instance_ids: List[str] = []
+        instances: List[Llumlet] = []
+        for _ in range(self.manager_args.initial_instances):
+            instance_id = random_uuid()
+            placement_group = self.launcher.init_placement_group(get_placement_group_name(instance_id), engine_args, backend_type)
+            instance = self.launcher.init_instance(instance_id, instance_args, placement_group, request_output_queue_type,
+                                                   backend_type, engine_args)
+            instance_ids.append(instance_id)
+            instances.append(instance)
+
+        # Because init_instances is called by multiple nodes simultaneously, we do not wait for the instances to be ready here.
+        self.scale_up(instance_ids, instances, [instance_args]*len(instance_ids))
+
+        return instance_ids, instances
+
+    async def is_ready(self) -> bool:
+        """Called by api server, return true when all the instances have been successfully created."""
+        tasks = [instance.is_ready.remote() for instance in self.instances.values()]
+        is_ready_list = await asyncio.gather(*tasks, return_exceptions=True)
+        return all(is_ready_list)
+
     async def _poll_instance_info_loop(self, interval: float) -> None:
         def get_instance_info_done_callback(instance_id: str, fut):
             ret = fut.result()[0]
@@ -346,65 +398,6 @@ async def _auto_scale_up_loop(self, interval: float) -> None:
                 logger.error("Unexpected exception: {}".format(e))
                 logger.error("Exception traceback: {}".format(traceback.format_exc()))
 
-    # TODO(KuilongCui): Add comments for this function.
-    async def _rebuild_migration_backend(self) -> None:
-        # Wait for all instances to finish migration
-        while any(self.instance_migrating.values()):
-            await asyncio.sleep(WAIT_ALL_MIGRATIONS_DONE_INTERVAL)
-
-        # During rebuilding migration backend, disable migration.
- origin_config = self.enable_migration - self.enable_migration = False - - async def run_task(alive_instances: List[str], task_name: str, *args, **kwargs): - tasks = [] - for instance_name in alive_instances: - llumlet_handle = self.instances[instance_name] - tasks.append(llumlet_handle.execute_engine_method.remote("_run_workers", task_name, *args, **kwargs)) - rets = await asyncio.gather(*tasks, return_exceptions=True) - dead_instances = set() - for instance_name, ret in zip(alive_instances, rets): - if isinstance(ret, ray.exceptions.RayActorError): - dead_instances.add(instance_name) - if len(dead_instances) > 0: - self.scale_down(dead_instances, rebuild_migration_backend=False) - clear_gloo_backend_state() - return dead_instances - - alive_instances = sorted(self.instances.keys()) - pending_task = self.pending_rebuild_migration_instances - group_name = None - clear_gloo_backend_state() - - while len(alive_instances) > 0 and self.pending_rebuild_migration_instances > 0: - dead_instances = set() - group_name = random_uuid() - instance_rank = {instance_id: index for index, instance_id in enumerate(alive_instances)} - dead_instances.update(await run_task(alive_instances, "rebuild_migration_backend", - instance_rank, group_name)) - if len(dead_instances) == 0 and self.pending_rebuild_migration_instances == pending_task: - dead_instances.update(await run_task(alive_instances, "warmup")) - if len(dead_instances) == 0: - self.pending_rebuild_migration_instances -= pending_task - alive_instances = sorted(set(self.instances.keys()) - dead_instances) - pending_task = self.pending_rebuild_migration_instances - - if len(alive_instances) == 0: - self.pending_rebuild_migration_instances = 0 - group_name = None - - migration_filter: CustomFilter = self.global_scheduler.migration_scheduler\ - .migration_filter.get_filter("migration_backend_init_filter") - migration_filter.set_filter_condtition( - src_filter=lambda instance_info: instance_info.instance_id in alive_instances, - dst_filter=lambda instance_info: instance_info.instance_id in alive_instances) - - logger.info("Rebuild migration backend done, group_name: {}, alive instance ({}): {}." - .format(group_name, len(alive_instances), alive_instances)) - - # Restore migrate config - self.enable_migration = origin_config - def scale_up(self, instance_id: Union[str, Iterable[str]], instance_actor_handle: Union[ray.actor.ActorHandle, Iterable[ray.actor.ActorHandle]], @@ -479,80 +472,68 @@ def scale_down(self, instance_id: Union[str, Iterable[str]], rebuild_migration_b return self.num_instances - async def _connect_to_instances(self): - def connect_to_instances_done_callback(instance_id: str, instance_actor_handle: "ray.actor.ActorHandle", fut): - ret = fut.result()[0] - if not isinstance(ret, Exception): - scale_up_instance_ids.append(instance_id) - scale_up_instance_actor_handles.append(instance_actor_handle) - scale_up_instance_args.append(ret) - logger.info("Connect to instance {}".format(instance_id)) - else: - logger.warning("Connect to instance {} failed, exception: {}".format(instance_id, ret)) + async def _check_deployment_states_loop(self, interval: float) -> None: + async def watch_instance_deployment_states(instance_id: str): + # There might be some delays of calling _init_server_and_instance, so sleep first. 
+ await asyncio.sleep(WATCH_DEPLOYMENT_INTERVAL) + wait_pending_instance_time = 0.0 + while True: + instance_state = list_actors(filters=[("name", "=", get_instance_name(instance_id))]) + instance_pending_creation = len(instance_state) == 1 and instance_state[0]["state"] == "PENDING_CREATION" + if not instance_pending_creation: + break + await asyncio.sleep(WATCH_DEPLOYMENT_INTERVAL) + wait_pending_instance_time += WATCH_DEPLOYMENT_INTERVAL + if wait_pending_instance_time >= WATCH_DEPLOYMENT_INTERVAL_PENDING_INSTANCE: + break + pg_created, server_alive, instance_alive = self._get_instance_deployment_states(instance_id) + if pg_created and (not server_alive or not instance_alive): + logger.warning("Instance {} deployment states incorrect, states: (pg {}, server {}, instance {})" + .format(instance_id, pg_created, server_alive, instance_alive)) + self.scale_down(instance_id) - # Must set True despite set namespance to llumnix. - actor_names_dict = ray.util.list_named_actors(all_namespaces=True) - instance_actor_names = [actor_name_dict['name'] for actor_name_dict in actor_names_dict - if actor_name_dict['name'].startswith(INSTANCE_NAME_PREFIX)] - instance_actor_handles = [ray.get_actor(actor_name, namespace='llumnix') for actor_name in instance_actor_names] - scale_up_instance_ids = [] - scale_up_instance_actor_handles = [] - scale_up_instance_args = [] - tasks = [] - for instance_actor_name, instance_actor_handle in zip(instance_actor_names, instance_actor_handles): - instance_id = instance_actor_name[len('instance_'):] - if instance_id not in self.instances: - task = asyncio.gather(instance_actor_handle.get_instance_args.remote(), return_exceptions=True) - task.add_done_callback(partial(connect_to_instances_done_callback, instance_id, instance_actor_handle)) - tasks.append(task) - await asyncio.gather(*tasks) - # The only function that can add instance actor handles to manager. - self.scale_up(scale_up_instance_ids, scale_up_instance_actor_handles, scale_up_instance_args) + while True: + try: + curr_pgs, curr_servers, curr_instances = self._get_cluster_deployment_states() + assert len(curr_pgs) >= max(len(curr_servers), len(curr_instances)) + tasks = [] + for instance_id in curr_pgs: + if instance_id not in curr_servers or instance_id not in curr_instances: + tasks.append(asyncio.create_task(watch_instance_deployment_states(instance_id))) + await asyncio.gather(*tasks, return_exceptions=True) + await asyncio.sleep(interval) + # pylint: disable=broad-except + except Exception as e: + logger.error("Unexpected exception: {}".format(e)) + logger.error("Exception traceback: {}".format(traceback.format_exc())) - @classmethod - def from_args(cls, - entrypoints_args: EntrypointsArgs, - manager_args: ManagerArgs, - instance_args: InstanceArgs, - engine_args, - launch_args: LaunchArgs, - ) -> "Manager": - manager_class = ray.remote(num_cpus=1, - max_restarts=-1, - name=get_manager_name(), - namespace="llumnix", - lifetime="detached")(cls) - manager = manager_class.remote( - entrypoints_args, - manager_args, - instance_args, - engine_args, - launch_args, - os.getcwd()) - return manager + # TODO(KuilongCui): Currently, only one naive state check policy is implemented, + # which prevents the cluster from consisting entirely of prefill or decode instances. 
+    async def _check_pd_deployment_states_loop(self, interval: float) -> None:
+        previous_pending_pg_names = None
 
-    def init_instances(self,
-                       request_output_queue_type: QueueType,
-                       backend_type: BackendType,
-                       instance_args: InstanceArgs,
-                       engine_args
-                       ) -> Tuple[List[str], List[Llumlet]]:
-        instance_ids: List[str] = []
-        instances: List[Llumlet] = []
-        for _ in range(self.manager_args.initial_instances):
-            instance_id = random_uuid()
-            placement_group = self.launcher.init_placement_group(get_placement_group_name(instance_id), engine_args, backend_type)
-            instance = self.launcher.init_instance(instance_id, instance_args, placement_group, request_output_queue_type,
-                                                   backend_type, engine_args)
-            instance_ids.append(instance_id)
-            instances.append(instance)
+        while True:
+            try:
+                pending_pg_states = list_placement_groups(filters=[("state", "=", "PENDING")])
+                rescheduling_pg_states = list_placement_groups(filters=[("state", "=", "RESCHEDULING")])
+                all_pending_pg_names = [pg.name for pg in pending_pg_states]
 
-        # Because init_instances is called by multiple nodes simultaneously, we dot not wait instances ready here.
-        self.scale_up(instance_ids, instances, [instance_args]*len(instance_ids))
+                if previous_pending_pg_names and len(rescheduling_pg_states) == 0:
+                    new_pending_pg_states = list_placement_groups(filters=[("state", "=", "PENDING")])
+                    all_new_pending_pg_names = [pg.name for pg in new_pending_pg_states]
+                    if len(set(previous_pending_pg_names).difference(set(all_new_pending_pg_names))) == 0:
+                        self._check_pd_deployment_states()
+                    previous_pending_pg_names = all_new_pending_pg_names
+                else:
+                    previous_pending_pg_names = all_pending_pg_names
 
-        return instance_ids, instances
+                await asyncio.sleep(interval)
+            # pylint: disable=broad-except
+            except Exception as e:
+                logger.error("Unexpected exception: {}".format(e))
+                logger.error("Exception traceback: {}".format(traceback.format_exc()))
 
-    def _inner_check_pd_deployment(self) -> str:
+    def _check_pd_deployment_states(self) -> str:
         prefill_instance_ids = self.global_scheduler.dispatch_scheduler.available_dispatch_instance_set
         cur_num_prefill = len(prefill_instance_ids)
         decode_instance_ids = self.global_scheduler.instance_id_set - prefill_instance_ids
@@ -576,72 +557,125 @@ def _inner_check_pd_deployment(self) -> str:
 
         return scale_down_instance_id
 
-    # TODO(KuilongCui): currently, only one naive state check policy is implemented, which prevents the
-    # cluster from consisting entirely of prefill or decode instances.
- async def _check_pd_deployment_states_loop(self, interval: float) -> None: - previous_penging_pg_names = None + def _get_cluster_deployment_states(self) -> Tuple[Dict[str, PlacementGroup], Dict[str, APIServerActor], Dict[str, Llumlet]]: + curr_pgs: Dict[str, PlacementGroup] = {} + curr_servers: Dict[str, PlacementGroup] = {} + curr_instances: Dict[str, Llumlet] = {} + + created_pg_states = list_placement_groups(filters=[("state", "=", "CREATED")]) + for created_pg_state in created_pg_states: + instance_id = created_pg_state["name"].split("_")[-1] + curr_pgs[instance_id] = ray.util.get_placement_group(created_pg_state["name"]) + + alive_actor_states = list_actors(filters=[("state", "=", "ALIVE")]) + for alive_actor_state in alive_actor_states: + if alive_actor_state["name"].startswith(SERVER_NAME_PREFIX): + instance_id = alive_actor_state["name"].split("_")[-1] + curr_servers[instance_id] = ray.get_actor(alive_actor_state["name"], namespace="llumnix") + elif alive_actor_state["name"].startswith(INSTANCE_NAME_PREFIX): + instance_id = alive_actor_state["name"].split("_")[-1] + curr_instances[instance_id] = ray.get_actor(alive_actor_state["name"], namespace="llumnix") + + return curr_pgs, curr_servers, curr_instances + + def _get_instance_deployment_states(self, instance_id: str): + pg_state = list_placement_groups(filters=[("name", "=", get_placement_group_name(instance_id))]) + pg_created = len(pg_state) == 1 and pg_state[0]["state"] == "CREATED" + server_state = list_actors(filters=[("name", "=", get_server_name(instance_id))]) + server_alive = len(server_state) == 1 and server_state[0]["state"] == "ALIVE" + instance_state = list_actors(filters=[("name", "=", get_instance_name(instance_id))]) + instance_alive = len(instance_state) == 1 and instance_state[0]["state"] == "ALIVE" + + return pg_created, server_alive, instance_alive - while True: - try: - pending_pg_states = list_placement_groups(filters=[("state", "=", "PENDING")]) - rescheduling_pg_states = list_placement_groups(filters=[("state", "=", "RESCHEDULING")]) - all_penging_pg_names = [pg.name for pg in pending_pg_states] + # TODO(KuilongCui): Add comments for this function. + async def _rebuild_migration_backend(self) -> None: + # Wait for all instances to finish migration + while any(self.instance_migrating.values()): + await asyncio.sleep(WAIT_ALL_MIGRATIONS_DONE_INTERVAL) - if previous_penging_pg_names and len(rescheduling_pg_states) == 0 : - new_pending_pg_states = list_placement_groups(filters=[("state", "=", "PENDING")]) - all_new_penging_pg_names = [pg.name for pg in new_pending_pg_states] - if len(set(previous_penging_pg_names).difference(set(all_new_penging_pg_names))) == 0: - self._inner_check_pd_deployment() - previous_penging_pg_names = all_new_penging_pg_names - else: - previous_penging_pg_names = all_penging_pg_names + # During rebuilding migration backend, disable migration. 
+ origin_config = self.enable_migration + self.enable_migration = False - await asyncio.sleep(interval) - # pylint: disable=broad-except - except Exception as e: - logger.error("Unexpected exception: {}".format(e)) - logger.error("Exception traceback: {}".format(traceback.format_exc())) + async def run_task(alive_instances: List[str], task_name: str, *args, **kwargs): + tasks = [] + for instance_name in alive_instances: + llumlet_handle = self.instances[instance_name] + tasks.append(llumlet_handle.execute_engine_method.remote("_run_workers", task_name, *args, **kwargs)) + rets = await asyncio.gather(*tasks, return_exceptions=True) + dead_instances = set() + for instance_name, ret in zip(alive_instances, rets): + if isinstance(ret, ray.exceptions.RayActorError): + dead_instances.add(instance_name) + if len(dead_instances) > 0: + self.scale_down(dead_instances, rebuild_migration_backend=False) + clear_gloo_backend_state() + return dead_instances - async def _check_deployment_states_loop(self, interval: float) -> None: - async def watch_instance_deployment_states(instance_id: str): - # There might be some delays of calling _init_server_and_instance, so sleep first. - await asyncio.sleep(WATCH_DEPLOYMENT_INTERVAL) - wait_pending_instance_time = 0.0 - while True: - instance_state = list_actors(filters=[("name", "=", get_instance_name(instance_id))]) - instance_pending_creation = len(instance_state) == 1 and instance_state[0]["state"] == "PENDING_CREATION" - if not instance_pending_creation: - break - await asyncio.sleep(WATCH_DEPLOYMENT_INTERVAL) - wait_pending_instance_time += WATCH_DEPLOYMENT_INTERVAL - if wait_pending_instance_time >= WATCH_DEPLOYMENT_INTERVAL_PENDING_INSTANCE: - break - pg_created, server_alive, instance_alive = self.launcher.get_instance_deployment_states(instance_id) - if pg_created and (not server_alive or not instance_alive): - logger.warning("Instance {} deployment states incorrect, states: (pg {}, server {}, instance {})" - .format(instance_id, pg_created, server_alive, instance_alive)) - self.scale_down(instance_id) + alive_instances = sorted(self.instances.keys()) + pending_task = self.pending_rebuild_migration_instances + group_name = None + clear_gloo_backend_state() - while True: - try: - curr_pgs, curr_servers, curr_instances = self.launcher.get_cluster_deployment() - assert len(curr_pgs) >= max(len(curr_servers), len(curr_instances)) - tasks = [] - for instance_id in curr_pgs: - if instance_id not in curr_servers or instance_id not in curr_instances: - tasks.append(asyncio.create_task(watch_instance_deployment_states(instance_id))) - await asyncio.gather(*tasks, return_exceptions=True) - await asyncio.sleep(interval) - # pylint: disable=broad-except - except Exception as e: - logger.error("Unexpected exception: {}".format(e)) - logger.error("Exception traceback: {}".format(traceback.format_exc())) + while len(alive_instances) > 0 and self.pending_rebuild_migration_instances > 0: + dead_instances = set() + group_name = random_uuid() + instance_rank = {instance_id: index for index, instance_id in enumerate(alive_instances)} + dead_instances.update(await run_task(alive_instances, "rebuild_migration_backend", + instance_rank, group_name)) + if len(dead_instances) == 0 and self.pending_rebuild_migration_instances == pending_task: + dead_instances.update(await run_task(alive_instances, "warmup")) + if len(dead_instances) == 0: + self.pending_rebuild_migration_instances -= pending_task + alive_instances = sorted(set(self.instances.keys()) - dead_instances) + 
pending_task = self.pending_rebuild_migration_instances
 
-    async def is_ready(self) -> bool:
-        """Called by api server, return true when all the instances have been successfully created."""
-        tasks = [instance.is_ready.remote() for instance in self.instances.values()]
-        is_ready_list = await asyncio.gather(*tasks, return_exceptions=True)
-        return all(is_ready_list)
+        if len(alive_instances) == 0:
+            self.pending_rebuild_migration_instances = 0
+            group_name = None
+
+        migration_filter: CustomFilter = self.global_scheduler.migration_scheduler\
+            .migration_filter.get_filter("migration_backend_init_filter")
+        migration_filter.set_filter_condtition(
+            src_filter=lambda instance_info: instance_info.instance_id in alive_instances,
+            dst_filter=lambda instance_info: instance_info.instance_id in alive_instances)
+
+        logger.info("Rebuild migration backend done, group_name: {}, alive instance ({}): {}."
+                    .format(group_name, len(alive_instances), alive_instances))
+
+        # Restore migration config
+        self.enable_migration = origin_config
+
+    async def _connect_to_instances(self):
+        def connect_to_instances_done_callback(instance_id: str, instance_actor_handle: "ray.actor.ActorHandle", fut):
+            ret = fut.result()[0]
+            if not isinstance(ret, Exception):
+                scale_up_instance_ids.append(instance_id)
+                scale_up_instance_actor_handles.append(instance_actor_handle)
+                scale_up_instance_args.append(ret)
+                logger.info("Connect to instance {}".format(instance_id))
+            else:
+                logger.warning("Connect to instance {} failed, exception: {}".format(instance_id, ret))
+
+        # all_namespaces must be set to True even though the namespace is already set to llumnix.
+        actor_names_dict = ray.util.list_named_actors(all_namespaces=True)
+        instance_actor_names = [actor_name_dict['name'] for actor_name_dict in actor_names_dict
+                                if actor_name_dict['name'].startswith(INSTANCE_NAME_PREFIX)]
+        instance_actor_handles = [ray.get_actor(actor_name, namespace='llumnix') for actor_name in instance_actor_names]
+        scale_up_instance_ids = []
+        scale_up_instance_actor_handles = []
+        scale_up_instance_args = []
+        tasks = []
+        for instance_actor_name, instance_actor_handle in zip(instance_actor_names, instance_actor_handles):
+            instance_id = instance_actor_name[len('instance_'):]
+            if instance_id not in self.instances:
+                task = asyncio.gather(instance_actor_handle.get_instance_args.remote(), return_exceptions=True)
+                task.add_done_callback(partial(connect_to_instances_done_callback, instance_id, instance_actor_handle))
+                tasks.append(task)
+        await asyncio.gather(*tasks)
+        # The only function that can add instance actor handles to manager.
+ self.scale_up(scale_up_instance_ids, scale_up_instance_actor_handles, scale_up_instance_args) async def _check_instance_error(self, migrate_instance_pairs: Tuple[str, str]) -> List[bool]: def check_instance_error_done_callback(idx: int, instance_id: str, fut): diff --git a/tests/e2e_test/test_migration.py b/tests/e2e_test/test_migration.py index aa4a9ff6..1fecbed4 100644 --- a/tests/e2e_test/test_migration.py +++ b/tests/e2e_test/test_migration.py @@ -93,8 +93,8 @@ def get_instance_num_blocks(): @pytest.mark.asyncio @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="at least 2 gpus required for migration bench") @pytest.mark.parametrize("model", ['/mnt/model/Qwen-7B']) -@pytest.mark.parametrize("migration_backend", ['rayrpc']) -@pytest.mark.parametrize("migrated_request_status", ['running']) +@pytest.mark.parametrize("migration_backend", ['rayrpc', 'gloo', 'nccl']) +@pytest.mark.parametrize("migrated_request_status", ['running', 'waiting']) async def test_migration_benchmark(ray_env, shutdown_llumnix_service, model, migration_backend, migrated_request_status): if migrated_request_status == 'waiting' and migration_backend != 'rayrpc': pytest.skip("When the migrated request status is waiting, only test the rayrpc migration backend.") diff --git a/tests/unit_test/global_scheduler/test_manager.py b/tests/unit_test/global_scheduler/test_manager.py index 3611fe04..4380b598 100644 --- a/tests/unit_test/global_scheduler/test_manager.py +++ b/tests/unit_test/global_scheduler/test_manager.py @@ -120,21 +120,17 @@ def init_manager(): ray.get(manager.is_ready.remote()) return manager + class MockManager(Manager): - async def init_placement_group(self, *args, **kwargs): + def init_placement_group(self, *args, **kwargs): return self.launcher.init_placement_group(*args, **kwargs) - async def init_server_and_instance(self, *args, **kwargs): + def init_server_and_instance(self, *args, **kwargs): return self.launcher.init_server_and_instance(*args, **kwargs) - async def clear_instance_ray_resources(self, instance_id: str): + def clear_instance_ray_resources(self, instance_id: str): return self.launcher.clear_instance_ray_resources(instance_id) - async def get_cluster_deployment(self): - return self.launcher.get_cluster_deployment() - - async def get_instance_deployment_states(self, instance_id: str): - return self.launcher.get_instance_deployment_states(instance_id) def init_manager_with_launch_mode(launch_mode, request_output_queue_type="rayqueue", enable_pd_disagg=False, pd_ratio="1:3", max_instances=-1): @@ -351,11 +347,11 @@ def test_poll_instance_info_loop_and_migrate(ray_env, manager): async def test_init_server_and_get_instance_deployment_states_and_instance_and_clear_instance_ray_resources(ray_env): manager, _, _, engine_args, _ = init_manager_with_launch_mode(LaunchMode.LOCAL) instance_id = random_uuid() - pg = await manager.init_placement_group(get_placement_group_name(instance_id), - engine_args, BackendType.VLLM, init_server=True) + pg = manager.init_placement_group(get_placement_group_name(instance_id), + engine_args, BackendType.VLLM, init_server=True) pg = ray.util.get_placement_group(get_placement_group_name(instance_id)) ray.get(pg.ready()) - await manager.init_server_and_instance(instance_id, EntrypointsArgs(), InstanceArgs(), engine_args, BackendType.VLLM, pg) + manager.init_server_and_instance(instance_id, EntrypointsArgs(), InstanceArgs(), engine_args, BackendType.VLLM, pg) # wait for scale up await asyncio.sleep(5.0) @@ -366,11 +362,11 @@ async def 
test_init_server_and_get_instance_deployment_states_and_instance_and_c num_instances = manager.scale_up(instance_id, instance, InstanceArgs()) assert num_instances == 1 - pg_created, server_alive, instance_alive = await manager.get_instance_deployment_states(instance_id) + pg_created, server_alive, instance_alive = manager._get_instance_deployment_states(instance_id) assert pg_created and server_alive and instance_alive # test clear_instance_ray_resources - await manager.clear_instance_ray_resources(instance_id) + manager.clear_instance_ray_resources(instance_id) # wait for remove and kill await asyncio.sleep(5.0) @@ -381,31 +377,31 @@ async def test_init_server_and_get_instance_deployment_states_and_instance_and_c instance_exists = is_actor_exists(get_instance_name(instance_id)) assert not instance_exists - pg_created, server_alive, instance_alive = await manager.get_instance_deployment_states(instance_id) + pg_created, server_alive, instance_alive = manager._get_instance_deployment_states(instance_id) assert not pg_created and not server_alive and not instance_alive @pytest.mark.asyncio @pytest.mark.parametrize("request_output_queue_type", ['rayqueue', 'zmq']) -async def test_auto_scale_up_loop_and_get_cluster_deployment(ray_env, request_output_queue_type): +async def test_auto_scale_up_loop_and_get_cluster_deployment_states(ray_env, request_output_queue_type): manager, _, _, _, _ = init_manager_with_launch_mode(LaunchMode.GLOBAL, request_output_queue_type) await asyncio.sleep(60.0) num_instances = manager.scale_up([], [], []) assert num_instances == 4 - curr_pgs, curr_servers, curr_instances = await manager.get_cluster_deployment() + curr_pgs, curr_servers, curr_instances = manager._get_cluster_deployment_states() assert len(curr_pgs) == 4 and len(curr_servers) == 4 and len(curr_instances) == 4 actor_names_dict = ray.util.list_named_actors(all_namespaces=True) instance_ids = [actor_name_dict['name'].split("_")[-1] for actor_name_dict in actor_names_dict if actor_name_dict['name'].startswith(INSTANCE_NAME_PREFIX)] assert len(instance_ids) == 4 - await manager.clear_instance_ray_resources(instance_ids[0]) - await manager.clear_instance_ray_resources(instance_ids[1]) + manager.clear_instance_ray_resources(instance_ids[0]) + manager.clear_instance_ray_resources(instance_ids[1]) await asyncio.sleep(60.0) num_instances = manager.scale_up([], [], []) assert num_instances == 4 - curr_pgs, curr_servers, curr_instances = await manager.get_cluster_deployment() + curr_pgs, curr_servers, curr_instances = manager._get_cluster_deployment_states() assert len(curr_pgs) == 4 and len(curr_servers) == 4 and len(curr_instances) == 4 @pytest.mark.asyncio @@ -416,7 +412,7 @@ async def test_check_deployment_states_loop_and_auto_scale_up_loop(ray_env, requ num_instances = manager.scale_up([], [], []) assert num_instances == 4 - curr_pgs, curr_servers, curr_instances = await manager.get_cluster_deployment() + curr_pgs, curr_servers, curr_instances = manager._get_cluster_deployment_states() assert len(curr_pgs) == 4 and len(curr_servers) == 4 and len(curr_instances) == 4 actor_names_dict = ray.util.list_named_actors(all_namespaces=True) @@ -431,20 +427,20 @@ async def test_check_deployment_states_loop_and_auto_scale_up_loop(ray_env, requ num_instances = manager.scale_up([], [], []) assert num_instances == 4 - curr_pgs, curr_servers, curr_instances = await manager.get_cluster_deployment() + curr_pgs, curr_servers, curr_instances = manager._get_cluster_deployment_states() assert len(curr_pgs) == 4 and 
len(curr_servers) == 4 and len(curr_instances) == 4 def test_pd_disagg_gloal_launch_instance_type(): launcher = Launcher(None, True, False, True, False, [1, 2]) assert launcher._get_next_instance_type(0, 0, [1, 2]) == InstanceType.PREFILL - launcher.inflight_num_prefill += 1 + launcher.inflight_num_prefill_instance += 1 assert launcher._get_next_instance_type(0, 0, [1, 2]) == InstanceType.DECODE - launcher.inflight_num_decode += 1 + launcher.inflight_num_decode_instance += 1 - launcher.inflight_num_prefill = 0 - launcher.inflight_num_decode = 0 + launcher.inflight_num_prefill_instance = 0 + launcher.inflight_num_decode_instance = 0 assert launcher._get_next_instance_type(1, 1, [1, 2]) == InstanceType.DECODE assert launcher._get_next_instance_type(1, 2, [1, 2]) == InstanceType.PREFILL @@ -461,7 +457,7 @@ async def test_pd_disagg_gloal_launch_deployment_and_auto_scale_up_loop(ray_env, num_instances = manager.scale_up([], [], []) assert num_instances == 4 - curr_pgs, curr_servers, curr_instances = await manager.get_cluster_deployment() + curr_pgs, curr_servers, curr_instances = manager._get_cluster_deployment_states() assert len(curr_pgs) == 4 and len(curr_servers) == 4 and len(curr_instances) == 4 num_prefill_instances = 0 @@ -493,7 +489,7 @@ async def test_pd_disagg_gloal_launch_deployment_and_auto_scale_up_loop(ray_env, num_instances = manager.scale_up([], [], []) assert num_instances == 4 - curr_pgs, curr_servers, curr_instances = await manager.get_cluster_deployment() + curr_pgs, curr_servers, curr_instances = manager._get_cluster_deployment_states() assert len(curr_pgs) == 4 and len(curr_servers) == 4 and len(curr_instances) == 4 num_prefill_instances = 0 @@ -518,26 +514,26 @@ async def test_pd_disagg_deployment_states(): instance_args=InstanceArgs(migration_backend="rayrpc"), engine_args=engine_args, launch_args=LaunchArgs(LaunchMode.LOCAL, BackendType.VLLM), work_dir=os.getcwd()) - assert not manager._inner_check_pd_deployment() + assert not manager._check_pd_deployment_states() prefill_instance_ids = [random_uuid() for _ in range(3)] decode_instance_ids = [random_uuid() for _ in range(3)] manager.scale_up(prefill_instance_ids, [None]*len(prefill_instance_ids), [InstanceArgs(instance_type="prefill")]*len(prefill_instance_ids)) - assert manager._inner_check_pd_deployment() in prefill_instance_ids + assert manager._check_pd_deployment_states() in prefill_instance_ids manager.scale_down(prefill_instance_ids) manager.scale_up(decode_instance_ids, [None]*len(decode_instance_ids), [InstanceArgs(instance_type="decode")]*len(decode_instance_ids)) - assert manager._inner_check_pd_deployment() in decode_instance_ids + assert manager._check_pd_deployment_states() in decode_instance_ids manager.scale_up(prefill_instance_ids, [None]*len(prefill_instance_ids), [InstanceArgs(instance_type="prefill")]*len(prefill_instance_ids)) - assert not manager._inner_check_pd_deployment() + assert not manager._check_pd_deployment_states() @pytest.mark.asyncio -async def test_auto_scale_up_loop_max_instances(ray_env): +async def test_auto_scale_up_loop_max_instances(): manager, _, _, _, _ = init_manager_with_launch_mode(LaunchMode.GLOBAL, "rayqueue", max_instances=2) await asyncio.sleep(60.0) num_instances = manager.scale_up([], [], []) From 608f17b31c83812b7ae0c540051bf55e6d69f399 Mon Sep 17 00:00:00 2001 From: s5u13b Date: Wed, 12 Feb 2025 07:46:34 +0000 Subject: [PATCH 54/59] Fix lint --- llumnix/manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llumnix/manager.py 
b/llumnix/manager.py index 4abab584..3daa2780 100644 --- a/llumnix/manager.py +++ b/llumnix/manager.py @@ -35,8 +35,8 @@ from llumnix.arg_utils import ManagerArgs, EntrypointsArgs, InstanceArgs, LaunchArgs from llumnix.server_info import ServerInfo from llumnix.backends.backend_interface import BackendType -from llumnix.utils import (random_uuid, clear_gloo_backend_state, get_server_name, - get_instance_name, get_manager_name, get_placement_group_name, +from llumnix.utils import (random_uuid, clear_gloo_backend_state, get_server_name, + get_instance_name, get_manager_name, get_placement_group_name, INSTANCE_NAME_PREFIX, SERVER_NAME_PREFIX, run_async_func_sync) from llumnix.entrypoints.utils import LaunchMode from llumnix.queue.queue_type import QueueType From a38a3740ce182c3e3592842610d4e7472ba02d1c Mon Sep 17 00:00:00 2001 From: s5u13b Date: Wed, 12 Feb 2025 08:28:30 +0000 Subject: [PATCH 55/59] Fix correctness test --- tests/e2e_test/test_correctness.py | 39 +++++++++++++++--------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/tests/e2e_test/test_correctness.py b/tests/e2e_test/test_correctness.py index 8af96cf5..eb93b6db 100644 --- a/tests/e2e_test/test_correctness.py +++ b/tests/e2e_test/test_correctness.py @@ -94,30 +94,31 @@ async def test_correctness(ray_env, shutdown_llumnix_service, model, launch_mode if launch_mode == "local": if enable_pd_disagg: launch_commands.append(generate_launch_command(result_filename=str(base_port)+".out", - model=model, - max_model_len=max_model_len, - port=base_port, - enable_pd_disagg=enable_pd_disagg, - instance_type="prefill")) + model=model, + max_model_len=max_model_len, + ip=ip, + port=base_port, + enable_pd_disagg=enable_pd_disagg, + instance_type="prefill")) launch_commands.append(generate_launch_command(result_filename=str(base_port+1)+".out", - launch_ray_cluster=False, - model=model, - max_model_len=max_model_len, - ip=ip, - port=base_port+1, - enable_pd_disagg=enable_pd_disagg, - instance_type="decode")) + launch_ray_cluster=False, + model=model, + max_model_len=max_model_len, + ip=ip, + port=base_port+1, + enable_pd_disagg=enable_pd_disagg, + instance_type="decode")) else: launch_commands.append(generate_launch_command(model=model, - max_model_len=max_model_len, - ip=ip, - port=base_port)) + max_model_len=max_model_len, + ip=ip, + port=base_port)) else: launch_commands.append(generate_serve_command(result_filename=str(base_port)+".out", - ip=ip, - port=base_port, - model=model, - enable_pd_disagg=enable_pd_disagg)) + ip=ip, + port=base_port, + model=model, + enable_pd_disagg=enable_pd_disagg)) for launch_command in launch_commands: subprocess.run(launch_command, shell=True, check=True) await asyncio.sleep(3) From 58e26475cff6e70d3eca6a9c385823a150ea0abb Mon Sep 17 00:00:00 2001 From: s5u13b Date: Wed, 19 Feb 2025 06:42:06 +0000 Subject: [PATCH 56/59] Rename power-of-k-choice to topk-random-dispatch --- docs/Arguments.md | 6 +++--- llumnix/arg_utils.py | 8 ++++---- llumnix/config/default.py | 2 +- llumnix/global_scheduler/dispatch_policy.py | 20 +++++++++---------- .../global_scheduler/dispatch_scheduler.py | 6 +++--- llumnix/global_scheduler/global_scheduler.py | 2 +- llumnix/internal_config.py | 4 ++-- .../test_dispatch_scheduler.py | 8 ++++---- 8 files changed, 28 insertions(+), 28 deletions(-) diff --git a/docs/Arguments.md b/docs/Arguments.md index e5c9e0bd..c2eda57e 100644 --- a/docs/Arguments.md +++ b/docs/Arguments.md @@ -25,7 +25,7 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h] [--scaling-load-metric 
{remaining_steps,usage_ratio}]
                                               [--polling-interval POLLING_INTERVAL]
                                               [--dispatch-policy {balanced,load,queue,rr}]
-                                              [--power-of-k-choice POWER_OF_K_CHOICE]
+                                              [--topk-random-dispatch TOPK_RANDOM_DISPATCH]
                                               [--enable-migration]
                                               [--enable-defrag]
                                               [--pair-migration-frequency PAIR_MIGRATION_FREQUENCY]
@@ -140,8 +140,8 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h]
 - Possible choices: balanced, load, queue, rr
 - Default: "load"
 
-`--power-of-k-choice`
-- Number of candidate instances for dispatch policy
+`--topk-random-dispatch`
+- Number of candidate random dispatch instances for dispatch policy.
 - Default: 1
 
 `--enable-migration`
diff --git a/llumnix/arg_utils.py b/llumnix/arg_utils.py
index f819ac55..c82881d2 100644
--- a/llumnix/arg_utils.py
+++ b/llumnix/arg_utils.py
@@ -121,7 +121,7 @@ class ManagerArgs:
     polling_interval: float = None
     dispatch_policy: str = None
     scaling_load_metric: str = None
-    power_of_k_choice: int = None
+    topk_random_dispatch: int = None
     enable_migration: bool = None
     pair_migration_frequency: int = None
@@ -177,7 +177,7 @@ def create_global_scheduler_config(self, is_group_kind_migration_backend) -> Tup
         # Create the GlobalScheduler Configuration.
         global_scheduler_config = GlobalSchedulerConfig(self.initial_instances,
                                                         self.dispatch_policy,
-                                                        self.power_of_k_choice,
+                                                        self.topk_random_dispatch,
                                                         self.pair_migration_policy,
                                                         self.migrate_out_threshold,
                                                         self.scaling_policy,
@@ -232,9 +232,9 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
                             '* "queue" dispatch request to the instance with minimum waiting request queue length.\n'
                             '* "flood" dispatch request to the instance with maximum requests dispatched.\n'
                             '* "rr" dispatch requests with round-robin policy.\n')
-        parser.add_argument('--power-of-k-choice',
+        parser.add_argument('--topk-random-dispatch',
                             type=int,
-                            help='number of candidate instances for dispatch policy.\n\n'
-                                 'The candidate instances are first selected according to the load'
-                                 '(including factors such as load, queue size, etc.) based on the dispatch policy,'
+                            help='number of candidate random dispatch instances for dispatch policy.\n\n'
+                                 'The candidate instances are first selected according to the load '
+                                 '(including factors such as load, queue size, etc.) based on the dispatch policy, '
                                  'and then one of them is randomly chosen to receive the request for better load balancing.')
diff --git a/llumnix/config/default.py b/llumnix/config/default.py
index 78be7e04..8b3cc231 100644
--- a/llumnix/config/default.py
+++ b/llumnix/config/default.py
@@ -80,7 +80,7 @@
 # Request dispatch policy
 _C.MANAGER.DISPATCH_POLICY = 'load'
 # Number of candidate instances for dispatch policy
-_C.MANAGER.POWER_OF_K_CHOICE = 1
+_C.MANAGER.TOPK_RANDOM_DISPATCH = 1
 
 # -------------------------- MIGRATION CONFIGURATION --------------------------
 # Enable migrate requests between instances
diff --git a/llumnix/global_scheduler/dispatch_policy.py b/llumnix/global_scheduler/dispatch_policy.py
index 1d594b00..02723981 100644
--- a/llumnix/global_scheduler/dispatch_policy.py
+++ b/llumnix/global_scheduler/dispatch_policy.py
@@ -19,8 +19,8 @@ def sort_instance_infos(available_instance_infos: List[InstanceInfo],
 )
 
 def random_choice_from_top_k(sorted_instance_infos: List[InstanceInfo],
-                             power_of_k_choice: int):
-    k = min(power_of_k_choice, len(sorted_instance_infos))
+                             topk_random_dispatch: int):
+    k = min(topk_random_dispatch, len(sorted_instance_infos))
     top_k_instance_infos = sorted_instance_infos[:k]
     return random.choice(top_k_instance_infos)
 
@@ -30,7 +30,7 @@ class DispatchPolicy(ABC):
     def dispatch(self,
                  instance_num_requests: Dict[str, int],
                  available_instance_infos: List[InstanceInfo],
-                 power_of_k_choice: int) -> int:
+                 topk_random_dispatch: int) -> int:
         pass
 
@@ -39,7 +39,7 @@ class Flood(DispatchPolicy):
     def dispatch(self,
                  instance_num_requests: Dict[str, int],
                  available_instance_infos: List[InstanceInfo],
-                 power_of_k_choice: int) -> str:
+                 topk_random_dispatch: int) -> str:
         instance_id = max(instance_num_requests, key=instance_num_requests.get)
         return instance_id
 
@@ -48,7 +48,7 @@ class Balanced(DispatchPolicy):
     def dispatch(self,
                  instance_num_requests: Dict[str, int],
                  available_instance_infos: List[InstanceInfo],
-                 power_of_k_choice: int) -> str:
+                 topk_random_dispatch: int) -> str:
         # dispatch request according to the number of requests dispatched to instance by manager
         instance_id = min(instance_num_requests, key=instance_num_requests.get)
         return instance_id
 
@@ -58,9 +58,9 @@ class Load(DispatchPolicy):
     def dispatch(self,
                  instance_num_requests: Dict[str, int],
                  available_instance_infos: List[InstanceInfo],
-                 power_of_k_choice: int) -> str:
+                 topk_random_dispatch: int) -> str:
         sorted_instance_infos = sort_instance_infos(available_instance_infos, 'dispatch_load_metric')
-        instance_info_chosen = random_choice_from_top_k(sorted_instance_infos, power_of_k_choice)
+        instance_info_chosen = random_choice_from_top_k(sorted_instance_infos, topk_random_dispatch)
         instance_id = instance_info_chosen.instance_id
         logger.info("dispatch to {}, load: {}".format(instance_id, instance_info_chosen.dispatch_load_metric))
         return instance_id
 
@@ -70,9 +70,9 @@ class Queue(DispatchPolicy):
     def dispatch(self,
                  instance_num_requests: Dict[str, int],
                  available_instance_infos: List[InstanceInfo],
-                 power_of_k_choice: int) -> str:
+                 topk_random_dispatch: int) -> str:
         sorted_instance_infos = sort_instance_infos(available_instance_infos, 'num_waiting_requests')
-        instance_info_chosen = random_choice_from_top_k(sorted_instance_infos, power_of_k_choice)
+        instance_info_chosen = random_choice_from_top_k(sorted_instance_infos, topk_random_dispatch)
         instance_id = instance_info_chosen.instance_id
         logger.info("dispatch to {}, queue size: {}".format(instance_id, instance_info_chosen.num_waiting_requests))
return instance_id @@ -84,7 +84,7 @@ class RoundRobin(DispatchPolicy): def dispatch(self, instance_num_requests: Dict[str, int], available_instance_infos: List[InstanceInfo], - power_of_k_choice: int) -> str: + topk_random_dispatch: int) -> str: all_instance_ids = sorted(instance_num_requests.keys()) cur_instance_idx = (self.prev_instance_idx + 1) % len(all_instance_ids) target_instance_id = all_instance_ids[cur_instance_idx] diff --git a/llumnix/global_scheduler/dispatch_scheduler.py b/llumnix/global_scheduler/dispatch_scheduler.py index e335f127..02ca8b96 100644 --- a/llumnix/global_scheduler/dispatch_scheduler.py +++ b/llumnix/global_scheduler/dispatch_scheduler.py @@ -25,9 +25,9 @@ class DispatchScheduler: def __init__(self, dispatch_policy: str, - power_of_k_choice: int) -> None: + topk_random_dispatch: int) -> None: self.dispatch_policy = DispatchPolicyFactory.get_policy(dispatch_policy) - self.power_of_k_choice = power_of_k_choice + self.topk_random_dispatch = topk_random_dispatch self.available_dispatch_instance_set: Set[str] = set() self.instance_info: Dict[str, InstanceInfo] = {} # statistics @@ -38,7 +38,7 @@ def dispatch(self) -> str: self.total_num_requests += 1 dispatch_instance_id = self.dispatch_policy.dispatch(self.instance_num_requests, self.instance_info.values(), - self.power_of_k_choice) + self.topk_random_dispatch) self.instance_num_requests[dispatch_instance_id] += 1 if self.total_num_requests % DISPATCH_LOG_FREQUENCY == 0: logger.info("dispatch scheduler total_dispatched_requests: {}".format(self.total_num_requests)) diff --git a/llumnix/global_scheduler/global_scheduler.py b/llumnix/global_scheduler/global_scheduler.py index 9f049a47..3acf144e 100644 --- a/llumnix/global_scheduler/global_scheduler.py +++ b/llumnix/global_scheduler/global_scheduler.py @@ -35,7 +35,7 @@ def __init__(self, global_scheduler_config: GlobalSchedulerConfig) -> None: # dispatch args self.dispatch_scheduler = DispatchScheduler(global_scheduler_config.dispatch_policy, - global_scheduler_config.power_of_k_choice) + global_scheduler_config.topk_random_dispatch) # migrate args self.migration_scheduler = MigrationScheduler(global_scheduler_config.pair_migration_policy, global_scheduler_config.migrate_out_load_threshold, diff --git a/llumnix/internal_config.py b/llumnix/internal_config.py index 4b02f8db..d8c79f0a 100644 --- a/llumnix/internal_config.py +++ b/llumnix/internal_config.py @@ -16,7 +16,7 @@ def __init__( self, initial_instances: int, dispatch_policy: str, - power_of_k_choice: int, + topk_random_dispatch: int, pair_migration_policy: str, migrate_out_threshold: float, scaling_policy: str, @@ -27,7 +27,7 @@ def __init__( is_group_kind_migration_backend: bool) -> None: self.initial_instances = initial_instances self.dispatch_policy = dispatch_policy - self.power_of_k_choice = power_of_k_choice + self.topk_random_dispatch = topk_random_dispatch self.pair_migration_policy = pair_migration_policy self.migrate_out_load_threshold = migrate_out_threshold diff --git a/tests/unit_test/global_scheduler/test_dispatch_scheduler.py b/tests/unit_test/global_scheduler/test_dispatch_scheduler.py index 37ab241a..320d5fea 100644 --- a/tests/unit_test/global_scheduler/test_dispatch_scheduler.py +++ b/tests/unit_test/global_scheduler/test_dispatch_scheduler.py @@ -155,11 +155,11 @@ def test_dispatch_rr(): target_instance_id = idx%instance_num assert instance_id == f'instance_{target_instance_id}' -def test_dispatch_power_of_k_choice(): +def test_dispatch_topk_random_dispatch(): num_tests = 100 instance_num 
= 4 - for power_of_k_choice in [1, 2, 3]: - dispatch_scheduler = DispatchScheduler('load', power_of_k_choice) + for topk_random_dispatch in [1, 2, 3]: + dispatch_scheduler = DispatchScheduler('load', topk_random_dispatch) instance_num_requests = {} instance_info_dict = {} for instance_id in [f'instance_{i}' for i in range(1, instance_num + 1)]: @@ -174,4 +174,4 @@ def test_dispatch_power_of_k_choice(): instance_id_set = set() for _ in range(num_tests): instance_id_set.add(dispatch_scheduler.dispatch()) - assert len(instance_id_set) == power_of_k_choice + assert len(instance_id_set) == topk_random_dispatch From 23510f223bd7c91a7531dc033f6bafc3d8b675f1 Mon Sep 17 00:00:00 2001 From: s5u13b Date: Wed, 19 Feb 2025 08:30:08 +0000 Subject: [PATCH 57/59] Fix proxy_actor --- llumnix/backends/vllm/migration_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llumnix/backends/vllm/migration_backend.py b/llumnix/backends/vllm/migration_backend.py index acbac62e..3b38f91f 100644 --- a/llumnix/backends/vllm/migration_backend.py +++ b/llumnix/backends/vllm/migration_backend.py @@ -24,7 +24,7 @@ logger = init_logger(__name__) -@ray.remote(num_cpus=1, max_concurrency=2) +@ray.remote(num_cpus=0, max_concurrency=2) class ProxyActor: def exec_method(self, is_driver_worker, handle, *args, **kwargs): try: From 876d8622261947251223ffb699b62cfc5961c7d9 Mon Sep 17 00:00:00 2001 From: s5u13b Date: Wed, 19 Feb 2025 08:30:13 +0000 Subject: [PATCH 58/59] Fix ray_env --- tests/conftest.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 54947846..ff6823e2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -28,9 +28,6 @@ from llumnix.utils import random_uuid -def pytest_sessionstart(session): - subprocess.run(["ray", "start", "--head", "--disable-usage-stats", "--port=6379"], check=False, - stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) def cleanup_ray_env_func(): try: @@ -70,9 +67,15 @@ def cleanup_ray_env_func(): @pytest.fixture def ray_env(): + subprocess.run(["ray", "start", "--head", "--disable-usage-stats", "--port=6379"], check=False, + stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + time.sleep(3.0) ray.init(namespace="llumnix", ignore_reinit_error=True) yield cleanup_ray_env_func() + time.sleep(1.0) + subprocess.run(["ray", "stop"], check=False, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + time.sleep(3.0) def backup_error_log(func_name): curr_time = datetime.now().strftime('%Y_%m_%d_%H_%M_%S') From 9b3fdd3e7db16fe59ebf4e848639d6ee52bb9f9a Mon Sep 17 00:00:00 2001 From: s5u13b Date: Wed, 19 Feb 2025 08:43:14 +0000 Subject: [PATCH 59/59] Fix import time in conftest --- tests/conftest.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/conftest.py b/tests/conftest.py index ff6823e2..98b1f462 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,6 +12,7 @@ # limitations under the License. from datetime import datetime +import time import shutil import os import subprocess
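The sketches below distill the core scheduling logic touched by this series; they are illustrative stand-ins written for this note, not code taken from the patches. First, the prefill/decode balancing rule behind Launcher._get_next_instance_type: the inflight and currently deployed counts are folded into two totals, InstanceType is stubbed with plain strings, and the NO_CONSTRAINTS branch is omitted since the sketch assumes enable_pd_disagg is set.

    from typing import List

    def next_instance_type(total_num_prefill: int,
                           total_num_decode: int,
                           pd_ratio: List[int]) -> str:
        # Bootstrap: guarantee at least one instance of each type first.
        if total_num_prefill == 0:
            return "prefill"
        if total_num_decode == 0:
            return "decode"
        # Strip complete pd_ratio groups; only the remainder matters for balance.
        base = min(total_num_prefill // pd_ratio[0], total_num_decode // pd_ratio[1])
        rem_prefill = total_num_prefill - base * pd_ratio[0]
        rem_decode = total_num_decode - base * pd_ratio[1]
        if rem_prefill + rem_decode == 0:
            return "prefill"
        # Launch whichever type keeps (prefill - decode) closest to pd_ratio[0] - pd_ratio[1].
        normal_distance = pd_ratio[0] - pd_ratio[1]
        gap_if_prefill = abs((rem_prefill + 1 - rem_decode) - normal_distance)
        gap_if_decode = abs((rem_prefill - (rem_decode + 1)) - normal_distance)
        return "prefill" if gap_if_prefill <= gap_if_decode else "decode"

    # Launching six instances from scratch with pd_ratio=[1, 2] yields
    # prefill, decode, decode, prefill, decode, decode:
    counts = {"prefill": 0, "decode": 0}
    for _ in range(6):
        counts[next_instance_type(counts["prefill"], counts["decode"], [1, 2])] += 1
    print(counts)  # {'prefill': 2, 'decode': 4}

This reproduces the expectations in test_pd_disagg_gloal_launch_instance_type: (0, 0) picks prefill, (1, 0) picks decode, (1, 1) picks decode, and (1, 2) picks prefill.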
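Second, the dispatch behavior behind the topk-random-dispatch rename: rank the candidate instances by a load metric, then pick uniformly at random among the best k. A minimal sketch, assuming a lower dispatch_load_metric means a less loaded instance and that sort_instance_infos sorts ascending; the InstanceInfo dataclass here is a stub for the real one.

    import random
    from dataclasses import dataclass
    from typing import List

    @dataclass
    class InstanceInfo:
        instance_id: str
        dispatch_load_metric: float

    def topk_random_dispatch(instance_infos: List[InstanceInfo], k: int) -> str:
        # Rank candidates from least to most loaded, then randomize within the top k.
        ranked = sorted(instance_infos, key=lambda info: info.dispatch_load_metric)
        return random.choice(ranked[:min(k, len(ranked))]).instance_id

    infos = [InstanceInfo(f"instance_{i}", load)
             for i, load in enumerate([0.9, 0.1, 0.5, 0.3])]
    print(topk_random_dispatch(infos, 1))  # always instance_1, the least loaded
    print(topk_random_dispatch(infos, 2))  # instance_1 or instance_3, at random

With k == 1 this degenerates to plain least-loaded dispatch, which is why test_dispatch_topk_random_dispatch expects exactly topk_random_dispatch distinct targets over repeated dispatches.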
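Third, watch_instance_deployment_states in _check_deployment_states_loop uses a bounded polling pattern: probe a pending state at a fixed interval, but stop waiting once it has persisted past a deadline. A generic sketch, with a plain callable standing in for the Ray list_actors query and the interval and timeout playing the roles of WATCH_DEPLOYMENT_INTERVAL and WATCH_DEPLOYMENT_INTERVAL_PENDING_INSTANCE:

    import asyncio
    from typing import Callable

    async def wait_while_pending(is_pending: Callable[[], bool],
                                 interval: float,
                                 timeout: float) -> bool:
        """Return True once the pending state clears, False if it outlives the timeout."""
        waited = 0.0
        while is_pending():
            if waited >= timeout:
                return False
            await asyncio.sleep(interval)
            waited += interval
        return True

    # Example: a probe that reports pending for the first three checks.
    state = {"checks": 0}
    def probe() -> bool:
        state["checks"] += 1
        return state["checks"] <= 3

    print(asyncio.run(wait_while_pending(probe, interval=0.01, timeout=1.0)))  # True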
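Last, the stuck-pending condition at the heart of _check_pd_deployment_states_loop: the PD balance check fires only when something was pending on the previous poll, no placement group is being rescheduled, and none of the previously pending groups has left the pending set since. A distilled sketch with the Ray placement-group queries replaced by plain sets:

    from typing import Set

    def pd_deployment_check_needed(previous_pending: Set[str],
                                   current_pending: Set[str],
                                   num_rescheduling: int) -> bool:
        # Only act when something was pending before and nothing is being rescheduled.
        if not previous_pending or num_rescheduling > 0:
            return False
        # No previously pending placement group has made progress: treat as stuck.
        return len(previous_pending - current_pending) == 0

    print(pd_deployment_check_needed({"pg_a"}, {"pg_a", "pg_b"}, 0))  # True: pg_a is still stuck
    print(pd_deployment_check_needed({"pg_a"}, {"pg_b"}, 0))          # False: pg_a made progress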