diff --git a/paxml/programs.py b/paxml/programs.py index a85fabdab..029f0adb3 100644 --- a/paxml/programs.py +++ b/paxml/programs.py @@ -376,7 +376,12 @@ def run(self, state: TrainState, step: int) -> TrainProgramOutput: logging.log_first_n( logging.INFO, '[PAX STATUS]: Writing summaries (attempt).', 5 ) - steps_per_sec = self._maybe_write_summaries(step, new_step, train_outputs) + steps_per_sec = self._maybe_write_summaries( + step, + new_step, + train_outputs, + train_p.compute_steps_per_sec_interval_steps + ) # Run eval at regular step interval. # While the eval ones below are post-model weight updates, hence we use the @@ -456,11 +461,12 @@ def train_input_partition_spec( """Returns the partition spec for the model training inputs.""" def _maybe_write_summaries( - self, step: int, new_step: int, train_outputs: StepFnOutput + self, + step: int, + new_step: int, + train_outputs: StepFnOutput, + compute_steps_per_sec_interval_steps: int ) -> float | None: - # Compute steps/sec every this many steps, revisit when necessary. - compute_steps_per_sec_interval_steps = 10 - steps_per_sec = None should_compute_steps_per_sec = ( new_step % compute_steps_per_sec_interval_steps == 0 diff --git a/paxml/tasks_lib.py b/paxml/tasks_lib.py index 2b51f64ca..811dc54d4 100644 --- a/paxml/tasks_lib.py +++ b/paxml/tasks_lib.py @@ -1240,6 +1240,8 @@ class Train: decode_start_after_n_steps: Starts decoder after N steps. decode_use_ema_states: If True, use ema states to run decode during train, note that in this case ema MUST be enabled in the learner. + compute_steps_per_sec_interval_steps: number of steps to average over when + computing steps/sec. profiler_num_steps: The number of steps to be captured by the profiler based on the step time estimate. profiler_min_duration_sec: The minimum duration to be captured by the @@ -1302,6 +1304,7 @@ class Train: decode_start_after_n_steps: int = 0 # TODO(zhishuai): verify this for a pjit model. decode_use_ema_states: bool = False + compute_steps_per_sec_interval_steps: int = 10 profiler_num_steps: int = 2 profiler_min_duration_sec: float = 1.0 profiler_capture_step: int | None = None