From e5ee36dab32477daa0b502227fb0c9a001f5e5ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A5vard=20Berland?= Date: Fri, 5 Jan 2024 09:33:52 +0100 Subject: [PATCH] Support max_running in Scheduler --- src/ert/scheduler/scheduler.py | 2 +- tests/unit_tests/scheduler/test_scheduler.py | 41 ++++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/src/ert/scheduler/scheduler.py b/src/ert/scheduler/scheduler.py index 4036e0e0ecf..7c134faffa8 100644 --- a/src/ert/scheduler/scheduler.py +++ b/src/ert/scheduler/scheduler.py @@ -159,7 +159,7 @@ async def execute( cancel_when_execute_is_done(self._update_avg_job_runtime()) start = asyncio.Event() - sem = asyncio.BoundedSemaphore(self._max_running) + sem = asyncio.BoundedSemaphore(self._max_running or len(self._jobs)) for iens, job in self._jobs.items(): self._tasks[iens] = asyncio.create_task( job(start, sem, self._max_submit) diff --git a/tests/unit_tests/scheduler/test_scheduler.py b/tests/unit_tests/scheduler/test_scheduler.py index 6342290063e..788c2a65fd5 100644 --- a/tests/unit_tests/scheduler/test_scheduler.py +++ b/tests/unit_tests/scheduler/test_scheduler.py @@ -2,6 +2,7 @@ import json import shutil from pathlib import Path +from typing import List import pytest from cloudevents.http import CloudEvent, from_json @@ -207,6 +208,46 @@ async def wait(): assert timeouteventfound +@pytest.mark.parametrize("max_running", [0, 1, 2, 10]) +async def test_max_running(max_running, mock_driver, storage, tmp_path): + runs: List[bool] = [] + + async def wait(): + nonlocal runs + runs.append(True) + await asyncio.sleep(0.01) + runs.append(False) + + # Ensemble size must be larger than max_running to be able + # to expose issues related to max_running + ensemble_size = max_running * 3 if max_running > 0 else 10 + + ensemble = storage.create_experiment().create_ensemble( + name="foo", ensemble_size=ensemble_size + ) + realizations = [ + create_stub_realization(ensemble, tmp_path, iens) + for iens in range(ensemble_size) + ] + + sch = scheduler.Scheduler( + mock_driver(wait=wait), realizations, max_running=max_running + ) + + assert await sch.execute() == EVTYPE_ENSEMBLE_STOPPED + + currently_running = 0 + max_running_observed = 0 + for run in runs: + currently_running += 1 if run else -1 + max_running_observed = max(max_running_observed, currently_running) + + if max_running > 0: + assert max_running_observed == max_running + else: + assert max_running_observed == ensemble_size + + @pytest.mark.timeout(6) async def test_max_runtime_while_killing(realization, mock_driver): wait_started = asyncio.Event()