From c12aedc0d2d6b60186015dc88091baafb2698503 Mon Sep 17 00:00:00 2001
From: Rob Nagler <5495179+robnagler@users.noreply.github.com>
Date: Wed, 3 Jan 2024 10:35:29 -0700
Subject: [PATCH] Fix #6608 normalized job_driver's agent life cycle code
 (#6617)

- Changed some conditionals to assertions to simplify code
- Locking around start and end of the agent life (`_agent_life_change_lock`)
- `free_resources` is async so it can use `_agent_life_change_lock`
- Ops are bound to the driver in `prepare_send` (`driver._prepared_sends`);
  binding in `agent_start` was too early
- Fix #6613 `_agent_start_delay` is local to op
- Fix #6572 added more logging, fixed some messages, and normalized others
- `websocket_ready_timeout` handled correctly. `_agent_receive_alive` calls
  `_websocket_ready_timeout_cancel` instead of `_agent_starting_done`, which
  has been removed. The old code caused `idle_timeout` to be canceled and
  restarted, when `idle_timeout` should only be started once the agent is
  ready.
- Cleaned up some naming (opId -> op_id when not in msg)
- Cascade exceptions into op.internal_error for better logging
- Improved error handling/logging in sirepo.status
- Replaced the Awaited exception with an enum return value,
  job_supervisor.SlotAllocStatus
- Changed `job_driver._slots_ready` to allocate all required slots instead
  of returning after the first await. Slots are not deallocated, so we do
  not have to validate them after each retry. The only place they are freed
  is `destroy_op`, which raises CancelTask, bypassing all that logic.
- Removed MAX_RETRIES, because it is unnecessary after the `_slots_ready`
  simplification. There is a single retry in `job_driver.prepare_send` to
  check `_agent_ready` if `_slots_ready` returns `HAD_TO_AWAIT`. The purpose
  of the Awaited exception was simply to ensure the agent is ready before
  `send` is called. The code is now much simpler and accomplishes the same
  thing.
- Refactored some asserts to raise AssertionError
---
 sirepo/job_driver/__init__.py | 291 +++++++++++++++++++---------------
 sirepo/job_driver/docker.py   |   9 +-
 sirepo/job_driver/local.py    |  25 +--
 sirepo/job_driver/sbatch.py   |  26 +--
 sirepo/job_supervisor.py      | 231 ++++++++++++++-------------
 sirepo/pkcli/job_agent.py     |   4 +-
 sirepo/status.py              |  34 ++--
 tests/job_timeout_test.py     |   2 -
 8 files changed, 339 insertions(+), 283 deletions(-)

diff --git a/sirepo/job_driver/__init__.py b/sirepo/job_driver/__init__.py
index ef4c726e2b..6b020282af 100644
--- a/sirepo/job_driver/__init__.py
+++ b/sirepo/job_driver/__init__.py
@@ -60,12 +60,10 @@ def assign_instance_op(op):
         res = _CLASSES[job.SBATCH].get_instance(op)
     else:
         res = _DEFAULT_CLASS.get_instance(op)
-    assert m.uid == res.uid, "op.msg.uid={} is not same as db.uid={} for jid={}".format(
-        m.uid,
-        res.uid,
-        m.get("computeJid"),
-    )
-    res.ops[op.opId] = op
+    if m.uid != res.uid:
+        raise AssertionError(
+            f"op.msg.uid={m.uid} is not same as db.uid={res.uid} for jid={m.get('computeJid')}",
+        )
     return res


@@ -78,18 +76,18 @@ def __init__(self, op):
         super().__init__(
             driver_details=PKDict({"type": self.__class__.__name__}),
             kind=op.kind,
-            ops=PKDict(),
             # TODO(robnagler) sbatch could override OP_RUN, but not OP_ANALYSIS
             # because OP_ANALYSIS touches the directory sometimes. Reasonably
             # there should only be one OP_ANALYSIS running on an agent at one time.
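The headline change in this patch, replacing the Awaited exception with a
SlotAllocStatus return value, is easiest to see in isolation. Below is a
minimal sketch of the new control flow; `agent_ready` and `slots_ready` are
simplified stand-ins for the driver methods, not the actual sirepo classes:

    import enum


    class SlotAllocStatus(enum.Enum):
        DID_NOT_AWAIT = 1
        HAD_TO_AWAIT = 2


    async def prepare_send(driver, op):
        # Make sure the agent is up before asking for slots
        await driver.agent_ready(op)
        # slots_ready reports whether it had to yield instead of raising
        # Awaited; the slots stay allocated either way
        if await driver.slots_ready(op) == SlotAllocStatus.HAD_TO_AWAIT:
            # The agent can die while we waited on slots, so recheck once;
            # this single recheck replaces the old MAX_RETRIES loop
            await driver.agent_ready(op)

Under the old scheme every await raised Awaited and callers restarted from
the top, up to MAX_RETRIES times; returning an enum lets the one piece of
state that can actually change during a yield (agent liveness) be rechecked
directly.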
op_slot_q=PKDict({k: job_supervisor.SlotQueue() for k in SLOT_OPS}), uid=op.msg.uid, - _agentId=job.unique_key(), - _agent_start_lock=tornado.locks.Lock(), - _agent_starting_timeout=None, + _agent_id=job.unique_key(), + _agent_life_change_lock=tornado.locks.Lock(), _idle_timer=None, + _prepared_sends=PKDict(), _websocket=None, _websocket_ready=sirepo.tornado.Event(), + _websocket_ready_timeout=None, ) self._sim_db_file_token = sirepo.sim_db_file.FileReq.token_for_user(self.uid) self._global_resources_token = ( @@ -99,33 +97,35 @@ def __init__(self, op): else None ) # Drivers persist for the life of the program so they are never removed - self.__instances[self._agentId] = self + self.__instances[self._agent_id] = self pkdlog("{}", self) def destroy_op(self, op): """Clear our op and (possibly) free cpu slot""" - self.ops.pkdel(op.opId) - op.cpu_slot.free() - if op.op_slot: - op.op_slot.free() + self._prepared_sends.pkdel(op.op_id) + if x := op.pkdel("cpu_slot"): + x.free() + if x := op.pkdel("op_slot"): + x.free() if "lib_dir_symlink" in op: # lib_dir_symlink is unique_key so not dangerous to remove pykern.pkio.unchecked_remove(op.pkdel("lib_dir_symlink")) - def free_resources(self, internal_error=None): + async def free_resources(self, caller): """Remove holds on all resources and remove self from data structures""" - pkdlog("{} internal_error={}", self, internal_error) try: - self._agent_starting_done() - self._websocket_ready.clear() - w = self._websocket - self._websocket = None - if w: - # Will not call websocket_on_close() - w.sr_close() - for o in list(self.ops.values()): - o.destroy(internal_error=internal_error) - self._websocket_free() + async with self._agent_life_change_lock: + await self.kill() + self._websocket_ready_timeout_cancel() + self._websocket_ready.clear() + w = self._websocket + self._websocket = None + if w: + # Will not call websocket_on_close() + w.sr_close() + e = f"job_driver.free_resources caller={caller}" + for o in list(self._prepared_sends.values()): + o.destroy(cancel_task=True, internal_error=e) except Exception as e: pkdlog("{} error={} stack={}", self, e, pkdexc()) @@ -156,33 +156,39 @@ def make_lib_dir_symlink(self, op): ) def op_is_untimed(self, op): - return op.opName in _UNTIMED_OPS + return op.op_name in _UNTIMED_OPS def pkdebug_str(self): return pkdformat( "{}(a={:.4} k={} u={:.4} {})", self.__class__.__name__, - self._agentId, + self._agent_id, self.kind, self.uid, - list(self.ops.values()), + list(self._prepared_sends.values()), ) async def prepare_send(self, op): - """Sends the op + """Awaits agent ready and slots for sending. - Returns: - bool: True if the op was actually sent + Agent is guaranteed to be ready and all slots are allocated + upon return. """ + + # If the agent is not ready after awaiting on slots, we need + # to recheck the agent, because agent can die (asynchronously) at any point + # while waiting for slots. 
         await self._agent_ready(op)
-        await self._slots_ready(op)
+        if await self._slots_ready(op) == job_supervisor.SlotAllocStatus.HAD_TO_AWAIT:
+            await self._agent_ready(op)
+        self._prepared_sends[op.op_id] = op

     @classmethod
     def receive(cls, msg):
         """Receive message from agent"""
         a = cls.__instances.get(msg.content.agentId)
         if a:
-            a._receive(msg)
+            a._agent_receive(msg)
             return
         pkdlog("unknown agent, sending kill; msg={}", msg)
         try:
@@ -200,29 +206,38 @@ async def terminate(cls):
             try:
                 # TODO(robnagler) need timeout
                 await d.kill()
-            except job_supervisor.Awaited:
-                pass
             except Exception as e:
                 # If one kill fails still try to kill the rest
                 pkdlog("error={} stack={}", e, pkdexc())

     def websocket_on_close(self):
         pkdlog("{}", self)
-        self.free_resources()
+        self._start_free_resources(caller="websocket_on_close")
+
+    def _websocket_ready_timeout_cancel(self):
+        if self._websocket_ready_timeout:
+            tornado.ioloop.IOLoop.current().remove_timeout(
+                self._websocket_ready_timeout
+            )
+            self._websocket_ready_timeout = None
+
+    async def _websocket_ready_timeout_handler(self):
+        pkdlog("{} timeout={}", self, self.cfg.agent_starting_secs)
+        await self.free_resources(caller="_websocket_ready_timeout_handler")

-    def _agent_cmd_stdin_env(self, **kwargs):
+    def _agent_cmd_stdin_env(self, op, **kwargs):
         return job.agent_cmd_stdin_env(
             ("sirepo", "job_agent", "start"),
-            env=self._agent_env(),
+            env=self._agent_env(op),
             uid=self.uid,
             **kwargs,
         )

-    def _agent_env(self, env=None):
+    def _agent_env(self, op, env=None):
         return job.agent_env(
             env=(env or PKDict()).pksetdefault(
                 PYKERN_PKDEBUG_WANT_PID_TIME="1",
-                SIREPO_PKCLI_JOB_AGENT_AGENT_ID=self._agentId,
+                SIREPO_PKCLI_JOB_AGENT_AGENT_ID=self._agent_id,
                 # POSIT: same as pkcli.job_agent.start
                 SIREPO_PKCLI_JOB_AGENT_DEV_SOURCE_DIRS=os.environ.get(
                     "SIREPO_PKCLI_JOB_AGENT_DEV_SOURCE_DIRS",
@@ -230,7 +245,7 @@ def _agent_env(self, env=None):
                 ),
                 SIREPO_PKCLI_JOB_AGENT_SUPERVISOR_GLOBAL_RESOURCES_TOKEN=self._global_resources_token,
                 SIREPO_PKCLI_JOB_AGENT_SUPERVISOR_GLOBAL_RESOURCES_URI=f"{self.cfg.supervisor_uri}{job.GLOBAL_RESOURCES_URI}",
-                SIREPO_PKCLI_JOB_AGENT_START_DELAY=self.get("_agent_start_delay", "0"),
+                SIREPO_PKCLI_JOB_AGENT_START_DELAY=str(op.get("_agent_start_delay", 0)),
                 SIREPO_PKCLI_JOB_AGENT_SUPERVISOR_SIM_DB_FILE_TOKEN=self._sim_db_file_token,
                 SIREPO_PKCLI_JOB_AGENT_SUPERVISOR_SIM_DB_FILE_URI=job.supervisor_file_uri(
                     self.cfg.supervisor_uri,
@@ -248,66 +263,18 @@ def _agent_env(self, env=None):
             uid=self.uid,
         )

+    def _agent_is_idle(self):
+        return not self._prepared_sends and not self._websocket_ready_timeout
+
     async def _agent_ready(self, op):
         if self._websocket_ready.is_set():
             return
         await self._agent_start(op)
         pkdlog("{} {} await _websocket_ready", self, op)
         await self._websocket_ready.wait()
-        pkdc("{} websocket alive", op)
-        raise job_supervisor.Awaited()
-
-    async def _agent_start(self, op):
-        if self._agent_starting_timeout:
-            return
-        async with self._agent_start_lock:
-            # POSIT: we do not have to raise Awaited(), because
-            # this is the first thing an op waits on.
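The `_websocket_ready_timeout` plumbing above follows the usual Tornado
pattern: keep the handle returned by `call_later` and disarm it with
`remove_timeout` once the awaited event arrives. A condensed, self-contained
sketch of that pattern (the class and argument names here are illustrative,
not sirepo's):

    import tornado.ioloop


    class TimeoutGuard:
        """Arm a deadline; disarm it when the awaited event arrives."""

        def __init__(self, secs, on_timeout):
            self._handle = tornado.ioloop.IOLoop.current().call_later(
                secs, on_timeout
            )

        def cancel(self):
            # Idempotent: safe to call whether or not the timeout fired
            if self._handle:
                tornado.ioloop.IOLoop.current().remove_timeout(self._handle)
                self._handle = None

In the patch, `_agent_receive_alive` disarms the guard via
`_websocket_ready_timeout_cancel`, which is why `idle_timeout` now starts
only once the agent is actually ready.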
- if self._agent_starting_timeout or self._websocket_ready.is_set(): - return - try: - t = self.cfg.agent_starting_secs - if pkconfig.channel_in_internal_test(): - x = op.msg.pkunchecked_nested_get("data.models.dog.favoriteTreat") - if x: - x = re.search(r"agent_start_delay=(\d+)", x) - if x: - self._agent_start_delay = int(x.group(1)) - t += self._agent_start_delay - pkdlog( - "op={} agent_start_delay={}", - op, - self._agent_start_delay, - ) - pkdlog("{} {} await _do_agent_start", self, op) - # All awaits must be after this. If a call hangs the timeout - # handler will cancel this task - self._agent_starting_timeout = ( - tornado.ioloop.IOLoop.current().call_later( - t, - self._agent_starting_timeout_handler, - ) - ) - # POSIT: Canceled errors aren't smothered by any of the below calls - await self.kill() - await self._do_agent_start(op) - except (Exception, sirepo.const.ASYNC_CANCELED_ERROR) as e: - pkdlog("{} error={} stack={}", self, e, pkdexc()) - self.free_resources(internal_error="failure starting agent") - raise - - def _agent_starting_done(self): - self._start_idle_timeout() - if self._agent_starting_timeout: - tornado.ioloop.IOLoop.current().remove_timeout(self._agent_starting_timeout) - self._agent_starting_timeout = None - - async def _agent_starting_timeout_handler(self): - pkdlog("{} timeout={}", self, self.cfg.agent_starting_secs) - await self.kill() - self.free_resources(internal_error="timeout waiting for agent to start") + pkdlog("{} {} websocket alive", self, op) - def _receive(self, msg): + def _agent_receive(self, msg): c = msg.content i = c.get("opId") if ("opName" not in c or c.opName == job.OP_ERROR) or ( @@ -323,64 +290,139 @@ def _receive(self, msg): if "reply" not in c: pkdlog("{} no reply={}", self, c) c.reply = PKDict(state="error", error="no reply") - if i in self.ops: + if i in self._prepared_sends: # SECURITY: only ops known to this driver can be replied to - self.ops[i].reply_put(c.reply) + self._prepared_sends[i].reply_put(c.reply) else: pkdlog( - "{} not pending opName={} o={:.4}", + "{} not in prepared_sends opName={} o={:.4} content={}", self, - i, c.opName, + i, + c, ) else: - getattr(self, "_receive_" + c.opName)(msg) + getattr(self, "_agent_receive_" + c.opName)(msg) - def _receive_alive(self, msg): + def _agent_receive_alive(self, msg): """Receive an ALIVE message from our agent Save the websocket and register self with the websocket """ - self._agent_starting_done() + self._websocket_ready_timeout_cancel() if self._websocket: if self._websocket != msg.handler: - pkdlog("{} new websocket", self) - # New _websocket so bind - self.free_resources() - self._websocket = msg.handler + raise AssertionError(f"incoming msg.content={msg.content}") + else: + self._websocket = msg.handler self._websocket_ready.set() self._websocket.sr_driver_set(self) + self._start_idle_timeout() - def __str__(self): - return f"{type(self).__name__}({self._agentId:.4}, {self.uid:.4}, ops={list(self.ops.values())})" - - def _receive_error(self, msg): + def _agent_receive_error(self, msg): # TODO(robnagler) what does this mean? Just a way of logging? Document this. pkdlog("{} msg={}", self, msg) + async def _agent_start(self, op): + if self._websocket_ready_timeout: + return + try: + async with self._agent_life_change_lock: + if self._websocket_ready_timeout or self._websocket_ready.is_set(): + return + pkdlog("{} {} await=_do_agent_start", self, op) + # All awaits must be after this. 
If a call hangs the timeout
+                # handler will cancel this task
+                self._websocket_ready_timeout = (
+                    tornado.ioloop.IOLoop.current().call_later(
+                        self._agent_start_delay(op),
+                        self._websocket_ready_timeout_handler,
+                    )
+                )
+                # POSIT: Canceled errors aren't smothered by any of the below calls
+                await self._do_agent_start(op)
+        except (Exception, sirepo.const.ASYNC_CANCELED_ERROR) as e:
+            pkdlog("{} error={} stack={}", self, e, pkdexc())
+            self._start_free_resources(caller="_agent_start")
+            raise
+
+    def _agent_start_delay(self, op):
+        t = self.cfg.agent_starting_secs
+        if not pkconfig.channel_in_internal_test():
+            return t
+        x = op.pkunchecked_nested_get("msg.data.models.dog.favoriteTreat")
+        if not x:
+            return t
+        x = re.search(r"agent_start_delay=(\d+)", x)
+        if not x:
+            return t
+        op._agent_start_delay = int(x.group(1))
+        pkdlog("op={} agent_start_delay={}", op, op._agent_start_delay)
+        return t + op._agent_start_delay
+
+    def __str__(self):
+        return f"{type(self).__name__}({self._agent_id:.4}, {self.uid:.4}, ops={list(self._prepared_sends.values())})"

     async def _slots_ready(self, op):
-        """Only one op of each type allowed"""
-        n = op.opName
+        """Allocate all required slots for op
+
+        Slot allocation may require yielding so `_agent_ready` needs
+        to be called if `HAD_TO_AWAIT` is true.
+
+        All slots are allocated and only freed when the op is
+        destroyed. We don't need to recheck the slots, because
+        `destroy_op` cancels this task. `_agent_ready` is state held
+        outside this op so it needs to be rechecked when
+        `HAD_TO_AWAIT` is returned.
+
+        Returns:
+            job_supervisor.SlotAllocStatus: whether coroutine had to yield
+        """
+
+        def _alloc_check(alloc_res):
+            if alloc_res == job_supervisor.SlotAllocStatus.HAD_TO_AWAIT:
+                return alloc_res
+            return res
+
+        n = op.op_name
+        res = job_supervisor.SlotAllocStatus.DID_NOT_AWAIT
         if n in (job.OP_CANCEL, job.OP_KILL, job.OP_BEGIN_SESSION):
-            return
+            return res
         if n == job.OP_SBATCH_LOGIN:
-            l = [o for o in self.ops.values() if o.opId != op.opId]
-            assert not l, "received {} but have other ops={}".format(op, l)
-            return
-        await op.op_slot.alloc("Waiting for another simulation to complete")
-        await op.run_dir_slot.alloc("Waiting for access to simulation state")
+            if self._prepared_sends:
+                raise AssertionError(
+                    f"received op={op} but have _prepared_sends={self._prepared_sends}",
+                )
+            return res
+        res = _alloc_check(
+            await op.op_slot.alloc(
+                "Waiting for another simulation to complete await=op_slot"
+            ),
+        )
+        res = _alloc_check(
+            await op.run_dir_slot.alloc(
+                "Waiting for access to simulation state await=run_dir_slot"
+            ),
+        )
         if n not in _CPU_SLOT_OPS:
-            return
+            return res
         # once job-op relative resources are acquired, ask for global resources
         # so we only acquire global resources once we know we are ready to go.
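`_agent_start` above is the classic async double-checked guard: test cheaply,
take `_agent_life_change_lock`, then retest before committing to the
expensive start. A minimal sketch of the idiom, assuming only
`tornado.locks.Lock` semantics (`AgentLifeCycle` and `do_start` are
placeholders, not sirepo names):

    import tornado.locks


    class AgentLifeCycle:
        def __init__(self):
            self._lock = tornado.locks.Lock()
            self._started = False

        async def start_once(self, do_start):
            # Cheap test before taking the lock
            if self._started:
                return
            async with self._lock:
                # Recheck under the lock: another coroutine may have won
                # the race while we waited (failure cleanup elided here)
                if self._started:
                    return
                self._started = True
                await do_start()

Because `free_resources` takes the same lock, an agent cannot be torn down in
the middle of being started, which is the point of normalizing the life cycle
around one lock.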
- await op.cpu_slot.alloc("Waiting for CPU resources") + res = _alloc_check( + await op.cpu_slot.alloc("Waiting for CPU resources await=cpu_slot"), + ) + return res + + def _start_free_resources(self, caller): + pkdlog("{} caller={}", self, caller) + tornado.ioloop.IOLoop.current().add_callback(self.free_resources, caller=caller) def _start_idle_timeout(self): async def _kill_if_idle(): self._idle_timer = None - if not self.ops: + if self._agent_is_idle(): pkdlog("{}", self) - await self.kill() + self._start_free_resources(caller="_kill_if_idle") else: self._start_idle_timeout() @@ -390,9 +432,6 @@ async def _kill_if_idle(): _kill_if_idle, ) - def _websocket_free(self): - pass - def init_module(**imports): global _cfg, _CLASSES, _DEFAULT_CLASS diff --git a/sirepo/job_driver/docker.py b/sirepo/job_driver/docker.py index 80157e95f5..2187501c6a 100644 --- a/sirepo/job_driver/docker.py +++ b/sirepo/job_driver/docker.py @@ -117,9 +117,10 @@ def init_class(cls, job_supervisor): return cls async def kill(self): - c = self.pkdel("_cid") - pkdlog("{} cid={:.12}", self, c) + c = None try: + c = self.pkdel("_cid") + pkdlog("{} cid={:.12}", self, c) # TODO(e-carlin): This can possibly hang and needs to be handled # Ex. docker daemon is not responsive await self._cmd( @@ -132,7 +133,7 @@ async def kill(self): pkdlog("{} error={} stack={}", self, e, pkdexc()) async def prepare_send(self, op): - if op.opName == job.OP_RUN: + if op.op_name == job.OP_RUN: op.msg.mpiCores = self.cfg[self.kind].get("cores", 1) return await super().prepare_send(op) @@ -169,7 +170,7 @@ def _constrain_resources(self, cfg_kind): ) async def _do_agent_start(self, op): - cmd, stdin, env = self._agent_cmd_stdin_env(cwd=self._agent_exec_dir) + cmd, stdin, env = self._agent_cmd_stdin_env(op, cwd=self._agent_exec_dir) pkdlog("{} agent_exec_dir={}", self, self._agent_exec_dir) pkio.mkdir_parent(self._agent_exec_dir) c = self.cfg[self.kind] diff --git a/sirepo/job_driver/local.py b/sirepo/job_driver/local.py index e8d22bfaa6..4ca7a4a918 100644 --- a/sirepo/job_driver/local.py +++ b/sirepo/job_driver/local.py @@ -29,7 +29,7 @@ def __init__(self, op): self.update( _agent_exec_dir=pkio.py_path(op.msg.userDir).join( "agent-local", - self._agentId, + self._agent_id, ), _agent_exit=tornado.locks.Event(), ) @@ -83,17 +83,20 @@ def init_class(cls, job_supervisor): async def kill(self): if "subprocess" not in self: return - pkdlog("{} pid={}", self, self.subprocess.proc.pid) - self.subprocess.proc.terminate() - self.kill_timeout = tornado.ioloop.IOLoop.current().call_later( - job_driver.KILL_TIMEOUT_SECS, - self.subprocess.proc.kill, - ) - await self._agent_exit.wait() - self._agent_exit.clear() + try: + pkdlog("{} pid={}", self, self.subprocess.proc.pid) + self.subprocess.proc.terminate() + self.kill_timeout = tornado.ioloop.IOLoop.current().call_later( + job_driver.KILL_TIMEOUT_SECS, + self.subprocess.proc.kill, + ) + await self._agent_exit.wait() + self._agent_exit.clear() + except Exception as e: + pkdlog("{} error={} stack={}", self, e, pkdexc()) async def prepare_send(self, op): - if op.opName == job.OP_RUN: + if op.op_name == job.OP_RUN: op.msg.mpiCores = sirepo.mpi.cfg().cores if op.msg.isParallel else 1 return await super().prepare_send(op) @@ -109,7 +112,7 @@ def _agent_on_exit(self, returncode): async def _do_agent_start(self, op): stdin = None try: - cmd, stdin, env = self._agent_cmd_stdin_env(cwd=self._agent_exec_dir) + cmd, stdin, env = self._agent_cmd_stdin_env(op, cwd=self._agent_exec_dir) pkdlog("{} agent_exec_dir={}", self, 
self._agent_exec_dir) # since this is local, we can make the directory; useful for debugging pkio.mkdir_parent(self._agent_exec_dir) diff --git a/sirepo/job_driver/sbatch.py b/sirepo/job_driver/sbatch.py index a100e3b831..66be2d2b8a 100644 --- a/sirepo/job_driver/sbatch.py +++ b/sirepo/job_driver/sbatch.py @@ -54,12 +54,15 @@ def _op_queue_size(op_kind): self.__instances[self.uid] = self async def kill(self): - if not self._websocket: + if not self.get("_websocket"): # if there is no websocket then we don't know about the agent # so we can't do anything return - # hopefully the agent is nice and listens to the kill - self._websocket.write_message(PKDict(opName=job.OP_KILL)) + try: + # hopefully the agent is nice and listens to the kill + self._websocket.write_message(PKDict(opName=job.OP_KILL)) + except Exception as e: + pkdlog("{} error={} stack={}", self, e, pkdexc()) @classmethod def get_instance(cls, op): @@ -121,7 +124,7 @@ async def prepare_send(self, op): ) ) m.runDir = "/".join((m.userDir, m.simulationType, m.computeJid)) - if op.opName == job.OP_RUN: + if op.op_name == job.OP_RUN: assert m.sbatchHours if self.cfg.cores: m.sbatchCores = min(m.sbatchCores, self.cfg.cores) @@ -129,11 +132,12 @@ async def prepare_send(self, op): m.shifterImage = self.cfg.shifter_image return await super().prepare_send(op) - def _agent_env(self): + def _agent_env(self, op): return super()._agent_env( + op, env=PKDict( SIREPO_SRDB_ROOT=self._srdb_root, - ) + ), ) async def _do_agent_start(self, op): @@ -144,7 +148,7 @@ async def _do_agent_start(self, op): set -e mkdir -p '{agent_start_dir}' cd '{self._srdb_root}' -{self._agent_env()} +{self._agent_env(op)} (/usr/bin/env; setsid {self.cfg.sirepo_cmd} job_agent start_sbatch) >& {log_file} & disown """ @@ -246,8 +250,12 @@ def _start_idle_timeout(self): """Sbatch agents should be kept alive as long as possible""" pass - def _websocket_free(self): - self._srdb_root = None + async def free_resources(self, *args, **kwargs): + try: + self._srdb_root = None + except Exception as e: + pkdlog("{} error={} stack={}", self, e, pkdexc()) + return await super().free_resources(*args, **kwargs) CLASS = SbatchDriver diff --git a/sirepo/job_supervisor.py b/sirepo/job_supervisor.py index c6d90b420e..9c9f918387 100644 --- a/sirepo/job_supervisor.py +++ b/sirepo/job_supervisor.py @@ -13,6 +13,7 @@ import asyncio import contextlib import copy +import enum import pykern.pkio import sirepo.const import sirepo.global_resources @@ -64,20 +65,15 @@ _cfg = None -#: how many times restart request when Awaited() raised -_MAX_RETRIES = 10 - - #: POSIT: same as sirepo.reply _REPLY_SR_EXCEPTION_STATE = "srException" _REPLY_ERROR_STATE = "error" _REPLY_STATE = "state" -class Awaited(Exception): - """An await occurred, restart operation""" - - pass +class SlotAllocStatus(enum.Enum): + DID_NOT_AWAIT = 1 + HAD_TO_AWAIT = 2 class ServerReq(PKDict): @@ -106,14 +102,15 @@ def __init__(self, **kwargs): async def alloc(self, situation): if self._value is not None: - return + return SlotAllocStatus.DID_NOT_AWAIT try: self._value = self._q.get_nowait() + return SlotAllocStatus.DID_NOT_AWAIT except tornado.queues.QueueEmpty: pkdlog("{} situation={}", self._op, situation) with self._op.set_job_situation(situation): self._value = await self._q.get() - raise Awaited() + return SlotAllocStatus.HAD_TO_AWAIT def free(self): if self._value is None: @@ -253,22 +250,16 @@ async def receive(cls, req): async def op_run_timeout(self, op): pass - def _create_op(self, opName, req, kind, job_run_mode, **kwargs): 
+    def _create_op(self, op_name, req, kind, job_run_mode, **kwargs):
         req.kind = kind
-        o = _Op(
+        return _Op(
             _supervisor=self,
             kind=req.kind,
-            msg=PKDict(req.copy_content()).pksetdefault(jobRunMode=job_run_mode),
-            opName=opName,
+            msg=PKDict(req.copy_content())
+            .pksetdefault(jobRunMode=job_run_mode)
+            .pkupdate(**kwargs),
+            op_name=op_name,
         )
-        if "dataFileKey" in o.msg:
-            kwargs["dataFileUri"] = job.supervisor_file_uri(
-                o.driver.cfg.supervisor_uri,
-                job.DATA_FILE_URI,
-                o.msg.pop("dataFileKey"),
-            )
-        o.msg.pkupdate(**kwargs)
-        return o

     def _get_running_pending_jobs(self, uid=None):
         def _filter_jobs(job):
@@ -381,13 +372,8 @@ async def _receive_api_beginSession(self, req):
         c = self._create_op(job.OP_BEGIN_SESSION, req, job.SEQUENTIAL, "sequential")
         try:
             await c.prepare_send()
-        except Awaited:
-            # OPTIMIZATION: _agent_ready is the first thing that could raise Awaited.
-            # In the event that it does, the agent is still started,
-            # so no need to try again after Awaited.
-            pass
         finally:
-            c.destroy(cancel=False)
+            c.destroy(cancel_task=False)
         return PKDict()

     async def _receive_api_globalResources(self, req):
@@ -565,7 +551,7 @@ def _purge_job(jid, qcall):
         )

     def set_situation(self, op, situation, exception=None):
-        if op.opName != job.OP_RUN:
+        if op.op_name != job.OP_RUN:
             return
         s = self.db.jobStatusMessage
         p = "Exception: "
@@ -737,7 +723,9 @@ def _ops_to_cancel():
                 # compute job. Both can have relevant data in the event of a canceled compute job.
                 # In the case of OP_IO we expect that the only reason for cancellation is due to
                 # a timeout (max_run_secs reached) in which case we send back "content-too-large".
-                if not (self.db.isParallel and o.opName in (job.OP_ANALYSIS, job.OP_IO))
+                if not (
+                    self.db.isParallel and o.op_name in (job.OP_ANALYSIS, job.OP_IO)
+                )
             )
             if timed_out_op in self.ops:
                 r.add(timed_out_op)
@@ -759,41 +747,33 @@ def _ops_to_cancel():
         ):
             # job is not relevant, but let the user know it isn't running
             return r
+        internal_error = None
         candidates = _ops_to_cancel()
-        c = None
-        o = set()
+        # must be after candidates so we don't cancel "c" itself
+        c = self._create_op(job.OP_CANCEL, req)
         # No matter what happens the job is canceled
         self.__db_update(status=job.CANCELED)
         self._canceled_serial = self.db.computeJobSerial
         try:
-            for i in range(_MAX_RETRIES):
-                try:
-                    o = _ops_to_cancel().intersection(candidates)
-                    if o:
-                        # TODO(robnagler) cancel run_op, not just by jid, which is insufficient (hash)
-                        if not c:
-                            c = self._create_op(job.OP_CANCEL, req)
-                        await c.prepare_send()
-                    elif c:
-                        c.destroy()
-                        c = None
-                    pkdlog("{} cancel={}", self, o)
-                    for x in o:
-                        x.destroy(cancel=True)
-                    if timed_out_op:
-                        self.db.canceledAfterSecs = timed_out_op.max_run_secs
-                    if c:
-                        c.msg.opIdsToCancel = [x.opId for x in o]
-                        c.send()
-                        await c.reply_get()
-                    return r
-                except Awaited:
-                    pass
-            else:
-                raise AssertionError("too many retries {}".format(req))
+            # TODO(robnagler) cancel run_op, not just by jid, which is insufficient (hash)
+            await c.prepare_send()
+            # Only cancel "old" ops. New ones should not be affected by this cancel.
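The comment above, and the intersection on the lines that follow, use a
snapshot-then-intersect idiom: `candidates` is captured before
`prepare_send` yields, and re-intersected afterward so ops created during
the await are left alone. A minimal sketch (`ops_to_cancel` and
`prepare_send` are placeholders for the patch's locals):

    async def cancel_old_ops(ops_to_cancel, prepare_send):
        # Snapshot before any await: these are the ops this request saw
        candidates = set(ops_to_cancel())
        # May yield; new, unrelated ops can be created in the meantime
        await prepare_send()
        # Cancel only ops that existed before AND still exist now
        return set(ops_to_cancel()) & candidates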
+            o = _ops_to_cancel().intersection(candidates)
+            if not o:
+                return r
+            pkdlog("{} to_cancel={}", self, o)
+            if timed_out_op:
+                self.__db_update(canceledAfterSecs=timed_out_op.max_run_secs)
+            for x in o:
+                x.destroy(cancel_task=True)
+            c.msg.opIdsToCancel = [x.op_id for x in o]
+            c.send()
+            await c.reply_get()
+            return r
+        except Exception as e:
+            internal_error = f"_receive_api_runSimulationCancel exception={e}"
+            raise
         finally:
-            if c:
-                c.destroy(cancel=False)
+            c.destroy(cancel_task=False, internal_error=internal_error)

     async def _receive_api_runSimulation(self, req, recursion_depth=0):
         f = req.content.data.get("forceRun")
@@ -834,7 +814,6 @@ async def _receive_api_runSimulation(self, req, recursion_depth=0):
                 computeJobQueued=t,
                 computeJobSerial=t,
                 computeModel=req.content.computeModel,
-                driverDetails=o.driver.driver_details,
                 # run mode can change between runs so we must update the db
                 jobRunMode=req.content.jobRunMode,
                 simName=req.content.data.models.simulation.name,
@@ -880,7 +859,7 @@ async def _receive_api_statefulCompute(self, req):
     async def _receive_api_statelessCompute(self, req):
         return await self._send_op_analysis(req, "stateless_compute")

-    def _create_op(self, opName, req, **kwargs):
+    def _create_op(self, op_name, req, **kwargs):
         req.simulationType = self.db.simulationType
         # run mode can change between runs so use req.content.jobRunMode
         # not self.db.jobRunMode
@@ -890,12 +869,12 @@
             raise sirepo.util.NotFound("invalid jobRunMode={} req={}", r, req)
         k = (
             job.PARALLEL
-            if self.db.isParallel and opName != job.OP_ANALYSIS
+            if self.db.isParallel and op_name != job.OP_ANALYSIS
             else job.SEQUENTIAL
         )
         o = (
             super()
-            ._create_op(opName, req, k, r, **kwargs)
+            ._create_op(op_name, req, k, r, **kwargs)
             .pkupdate(task=asyncio.current_task())
         )
         self.ops.append(o)
@@ -920,14 +899,7 @@ def _set_error(compute_job_serial, internal_error):

         async def _send_op(op, compute_job_serial, prev_db):
             try:
-                for _ in range(_MAX_RETRIES):
-                    try:
-                        await op.prepare_send()
-                        break
-                    except Awaited:
-                        pass
-                else:
-                    raise AssertionError(f"too many retries {op}")
+                await op.prepare_send()
             except sirepo.const.ASYNC_CANCELED_ERROR:
                 if self.pkdel("_canceled_serial") != compute_job_serial:
                     # There was a timeout getting the run started. Set the
@@ -940,7 +912,7 @@ async def _send_op(op, compute_job_serial, prev_db):
                     pass
                 raise
             except Exception as e:
-                op.destroy(cancel=False)
+                op.destroy(cancel_task=False, internal_error=f"_send_op exception={e}")
                 if isinstance(e, sirepo.util.SRException) and e.sr_args.params.get(
                     "isGeneral"
                 ):
@@ -949,6 +921,7 @@
                     return False
                 _set_error(compute_job_serial, op.internal_error)
                 raise
+            self.__db_update(driverDetails=op.driver.driver_details)
             op.make_lib_dir_symlink()
             op.send()
             return True
@@ -964,6 +937,9 @@
                 self.db.queueState = None
                 # TODO(robnagler) is this ever true?
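With slots held until the op is destroyed, `_send_op` above no longer needs
a retry loop: a single `prepare_send` either succeeds or the op is destroyed
with the exception recorded in `internal_error`. A sketch of that error
funnel (`send_once` is a placeholder name; `destroy` and `internal_error`
mirror the patch):

    async def send_once(op):
        try:
            await op.prepare_send()
        except Exception as e:
            # Cascade the exception into internal_error for later logging,
            # then destroy; no retry loop is needed because slots stay
            # allocated across awaits
            op.destroy(cancel_task=False, internal_error=f"send_once exception={e}")
            raise
        op.send()
        return await op.reply_get()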
if op != self.run_op: + pkdlog( + "ignore op={} because not run_op={}", op, self.run_op + ) return # run_dir is in a stable state so don't need to lock op.run_dir_slot.free() @@ -991,10 +967,14 @@ async def _send_op(op, compute_job_serial, prev_db): if op == self.run_op: self.__db_update( status=job.ERROR, + internal_error=f"_run exception={e}", error="server error", ) + else: + pkdlog("no db_update op={} because not run_op={}", op, self.run_op) + finally: - op.destroy(cancel=False) + op.destroy(cancel_task=False) async def _send_op_analysis(self, req, jobCmd): pkdlog( @@ -1006,26 +986,24 @@ async def _send_op_analysis(self, req, jobCmd): return await self._send_with_single_reply(job.OP_ANALYSIS, req, jobCmd=jobCmd) - async def _send_with_single_reply(self, opName, req, **kwargs): - o = self._create_op(opName, req, **kwargs) + async def _send_with_single_reply(self, op_name, req, **kwargs): + o = self._create_op(op_name, req, **kwargs) + internal_error = None try: - for i in range(_MAX_RETRIES): - try: - await o.prepare_send() - o.send() - r = await o.reply_get() - # POSIT: any api_* that could run into runDirNotFound - # will call _send_with_single_reply() and this will - # properly format the reply - if r.get("runDirNotFound"): - return self._init_db_missing_response(req) - return r - except Awaited: - pass - else: - raise AssertionError("too many retries {}".format(req)) + await o.prepare_send() + o.send() + r = await o.reply_get() + # POSIT: any api_* that could run into runDirNotFound + # will call _send_with_single_reply() and this will + # properly format the reply + if r.get("runDirNotFound"): + return self._init_db_missing_response(req) + return r + except Exception as e: + internal_error = f"_send_with_single_reply exception={e}" + raise finally: - o.destroy(cancel=False) + o.destroy(cancel_task=False, internal_error=internal_error) def _status_reply(self, req): def res(**kwargs): @@ -1080,45 +1058,70 @@ def __init__(self, *args, **kwargs): self.update( do_not_send=False, internal_error=None, - opId=job.unique_key(), + op_id=job.unique_key(), _reply_q=sirepo.tornado.Queue(), ) if "run_dir_slot_q" in self._supervisor: self.run_dir_slot = self._supervisor.run_dir_slot_q.sr_slot_proxy(self) - self.msg.update(opId=self.opId, opName=self.opName) - self.driver = job_driver.assign_instance_op(self) - self.cpu_slot = self.driver.cpu_slot_q.sr_slot_proxy(self) - q = self.driver.op_slot_q.get(self.opName) - self.op_slot = q and q.sr_slot_proxy(self) - self.max_run_secs = self._get_max_run_secs() + self.msg.update(opId=self.op_id, opName=self.op_name) pkdlog("{} runDir={}", self, self.msg.get("runDir")) - def destroy(self, cancel=True, internal_error=None): - "run_dir_slot" in self and self.run_dir_slot.free() - if cancel and self.get("task"): - self.task.cancel() - self.task = None - # Ops can be destroyed multiple times - # The first error is "closest to the source" so don't overwrite it - if not self.internal_error: - self.internal_error = internal_error + def destroy(self, cancel_task=True, internal_error=None): + """Idempotently destroy op + + Ops can be destroyed multiple times. The first + `internal_error` is "closest to the source" so it won't be + overwritten by subsequent calls unless it is `None`. 
+
+        Args:
+            cancel_task (bool): cancel `self.task` if True [default: True]
+            internal_error (str): saved for logging in `destroy_op` [default: None]
+
+        """
+        if x := self.pkdel("run_dir_slot"):
+            x.free()
+        if (x := self.pkdel("task")) and cancel_task:
+            x.cancel()
         for x in "run_callback", "timer":
-            if x in self:
-                tornado.ioloop.IOLoop.current().remove_timeout(self.pkdel(x))
+            if y := self.pkdel(x):
+                tornado.ioloop.IOLoop.current().remove_timeout(y)
+        if internal_error and not self.internal_error:
+            self.internal_error = internal_error
         self._supervisor.destroy_op(self)
-        self.driver.destroy_op(self)
+        if "driver" in self:
+            self.driver.destroy_op(self)

     def make_lib_dir_symlink(self):
         self.driver.make_lib_dir_symlink(self)

     def pkdebug_str(self):
-        return pkdformat("_Op({}, {:.4})", self.opName, self.opId)
+        def _internal_error():
+            if not self.internal_error:
+                return ""
+            return f", internal_error={self.internal_error}"
+
+        return pkdformat(
+            "_Op({}, {:.4}{})", self.op_name, self.op_id, _internal_error()
+        )

     async def prepare_send(self):
         """Ensures resources are available for sending to agent

         To maintain consistency, do not modify global state before
         calling this method.
         """
+        if "driver" not in self:
+            self.driver = job_driver.assign_instance_op(self)
+            pkdlog("assigned driver={} to op={}", self.driver, self)
+            self.cpu_slot = self.driver.cpu_slot_q.sr_slot_proxy(self)
+            if q := self.driver.op_slot_q.get(self.op_name):
+                self.op_slot = q.sr_slot_proxy(self)
+            self.max_run_secs = self._get_max_run_secs()
+            if "dataFileKey" in self.msg:
+                self.msg.dataFileUri = job.supervisor_file_uri(
+                    self.driver.cfg.supervisor_uri,
+                    job.DATA_FILE_URI,
+                    self.msg.pop("dataFileKey"),
+                )
         await self.driver.prepare_send(self)

     async def reply_get(self):
@@ -1162,14 +1165,14 @@ def set_job_situation(self, situation):

     def _get_max_run_secs(self):
         if self.driver.op_is_untimed(self):
             return 0
-        if self.opName in (
+        if self.op_name in (
             sirepo.job.OP_ANALYSIS,
             sirepo.job.OP_IO,
         ):
-            return _cfg.max_secs[self.opName]
+            return _cfg.max_secs[self.op_name]
         if self.kind == job.PARALLEL and self.msg.get("isPremiumUser"):
             return _cfg.max_secs["parallel_premium"]
         return _cfg.max_secs[self.kind]

     def __hash__(self):
-        return hash((self.opId,))
+        return hash((self.op_id,))
diff --git a/sirepo/pkcli/job_agent.py b/sirepo/pkcli/job_agent.py
index a7793161c2..748f12d97c 100644
--- a/sirepo/pkcli/job_agent.py
+++ b/sirepo/pkcli/job_agent.py
@@ -170,11 +170,11 @@ def fastcgi_destroy(self):
         self._fastcgi_file = None
         self.fastcgi_cmd = None

-    def format_op(self, msg, opName, **kwargs):
+    def format_op(self, msg, op_name, **kwargs):
         if msg:
             kwargs["opId"] = msg.get("opId")
         return pkjson.dump_bytes(
-            PKDict(agentId=_cfg.agent_id, opName=opName).pksetdefault(**kwargs),
+            PKDict(agentId=_cfg.agent_id, opName=op_name).pksetdefault(**kwargs),
         )

     async def job_cmd_reply(self, msg, op_name, text):
diff --git a/sirepo/status.py b/sirepo/status.py
index 5406bb9a1b..d320f4940a 100644
--- a/sirepo/status.py
+++ b/sirepo/status.py
@@ -54,13 +54,11 @@ async def _run_tests(self):
             res.destroy()
             m = re.search(r"/source/(\w+)$", c.uri)
             if not m:
-                raise RuntimeError("failed to find sid in resp={}".format(c))
+                raise RuntimeError(f"failed to find sid in resp={c}")
             i = m.group(1)
             d = simulation_db.read_simulation_json(simulation_type, sid=i, qcall=self)
             try:
-                d.models.electronBeam.current = d.models.electronBeam.current + (
-                    random.random() / 10
-                )
+                d.models.electronBeam.current += random.random() / 10
             except AttributeError:
                 assert (
_cfg.sim_type == "myapp" @@ -76,23 +74,29 @@ async def _run_tests(self): r = resp.content_as_object() resp.destroy() resp = None - pkdlog("resp={}", r) if r.state == "error": - raise RuntimeError("simulation error: resp={}".format(r)) + raise RuntimeError(f"state=error sid={i} resp={r}") if r.state == "completed": + pkdlog("status=completed sid={}", i) if "initialIntensityReport" == d.report: - min_size = 50 - if len(r.z_matrix) < min_size or len(r.z_matrix[0]) < min_size: - raise RuntimeError("received bad report output: resp={}", r) + m = 50 + if len(r.z_matrix) < m: + raise RuntimeError( + f"len(r.z_matrix)={len(r.z_matrix)} < {m} resp={r}", + ) + if len(r.z_matrix[0]) < m: + raise RuntimeError( + f"len(r.z_matrix[0])={len(r.z_matrix[0])} < {m} resp={r}", + ) return - d = r.nextRequest + if (d := r.get("nextRequest")) is None: + raise RuntimeError( + f"nextRequest missing state={r.get('state')} resp={r}" + ) + resp = await self.call_api("runStatus", data=d) await asyncio.sleep(_SLEEP) - raise RuntimeError( - "simulation timed out: seconds={} resp=".format( - _cfg.max_calls * _SLEEP, r - ), - ) + raise RuntimeError(f"timeout={_cfg.max_calls * _SLEEP}s last resp={r}") finally: if resp: resp.destroy() diff --git a/tests/job_timeout_test.py b/tests/job_timeout_test.py index f7a198cde7..988b4808e3 100644 --- a/tests/job_timeout_test.py +++ b/tests/job_timeout_test.py @@ -1,10 +1,8 @@ -# -*- coding: utf-8 -*- """test for canceling a long running simulation due to a timeout :copyright: Copyright (c) 2019 RadiaSoft LLC. All Rights Reserved. :license: http://www.apache.org/licenses/LICENSE-2.0.html """ -from __future__ import absolute_import, division, print_function from pykern.pkcollections import PKDict import os import pytest
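The reworked sirepo.status loop above polls runStatus until the simulation
completes, following `nextRequest` from each reply and failing loudly when it
is missing. A compact sketch of that polling contract (`call_api`,
`max_calls`, and `sleep_secs` stand in for the module's `self.call_api`,
`_cfg.max_calls`, and `_SLEEP`):

    import asyncio


    async def poll_until_complete(call_api, first_request, max_calls, sleep_secs):
        d = first_request
        r = None
        for _ in range(max_calls):
            r = await call_api("runStatus", data=d)
            s = r.get("state")
            if s == "error":
                raise RuntimeError(f"state=error resp={r}")
            if s == "completed":
                return r
            # The server says how to ask next time; a missing nextRequest
            # means the reply is malformed
            if (d := r.get("nextRequest")) is None:
                raise RuntimeError(f"nextRequest missing resp={r}")
            await asyncio.sleep(sleep_secs)
        raise RuntimeError(f"timeout={max_calls * sleep_secs}s last resp={r}")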