Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
KuilongCui committed Jan 16, 2025
1 parent 88c3102 commit 40e9125
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
2 changes: 1 addition & 1 deletion llumnix/llumlet/llumlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ async def _check_engine_state_loop(self):
# pylint: disable=protected-access
self.backend_engine._stop_event.set()
await asyncio.sleep(0)
self_actor = ray.get_actor(self.actor_name)
self_actor = ray.get_actor(name=self.actor_name, namespace="llumnix")
ray.kill(self_actor)

async def migrate_out(self, dst_instance_name: str) -> List[str]:
Expand Down
19 changes: 17 additions & 2 deletions tests/unit_test/llumlet/test_engine_step_exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import time
import ray
import torch
Expand All @@ -35,14 +36,27 @@ def __init__(self, *args, **kwargs) -> None:
kwargs["placement_group"] = placement_group
super().__init__(*args, **kwargs)
self.origin_step = self.backend_engine.engine.step_async

def set_error_step(self, broken: bool):
self.backend_engine._stop_event.set()

async def raise_error_step():
await self.origin_step()
raise ValueError("Mock engine step error")
self.backend_engine.engine.step_async = raise_error_step

if broken:
self.backend_engine.engine.step_async = raise_error_step
else:
self.backend_engine.engine.step_async = self.origin_step

asyncio.create_task(self.backend_engine._start_engine_step_loop())

@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Need at least 1 GPU to run the test.")
def test_engine_step_exception(ray_env):
engine_args = EngineArgs(model="facebook/opt-125m", max_model_len=8, worker_use_ray=True)

# Wait for the previous test to release its GPU memory.
time.sleep(5.0)
device_count = torch.cuda.device_count()
origin_free_memory_list = []
for device_id in range(device_count):
Expand All @@ -69,7 +83,8 @@ def test_engine_step_exception(ray_env):
cur_free_memory_list.append(cur_free_memory)
assert origin_free_memory_list != cur_free_memory_list

time.sleep(5)
ray.get(llumlet.set_error_step.remote(True))
time.sleep(3)

all_actors = ray.util.list_named_actors(True)
all_actor_names = [actor["name"] for actor in all_actors]
Expand Down

0 comments on commit 40e9125

Please sign in to comment.