Skip to content

Commit

Permalink
Merge pull request #69 from nvjax-svc-0:patch/configure-steps-per-sec…
Browse files Browse the repository at this point in the history
…-interval

PiperOrigin-RevId: 608753065
  • Loading branch information
pax authors committed Feb 20, 2024
2 parents f11a938 + 0720e17 commit 5297b4f
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 5 deletions.
16 changes: 11 additions & 5 deletions paxml/programs.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,12 @@ def run(self, state: TrainState, step: int) -> TrainProgramOutput:
logging.log_first_n(
logging.INFO, '[PAX STATUS]: Writing summaries (attempt).', 5
)
steps_per_sec = self._maybe_write_summaries(step, new_step, train_outputs)
steps_per_sec = self._maybe_write_summaries(
step,
new_step,
train_outputs,
train_p.compute_steps_per_sec_interval_steps,
)

# Run eval at regular step interval.
# While the eval ones below are post-model weight updates, hence we use the
Expand Down Expand Up @@ -458,11 +463,12 @@ def train_input_partition_spec(
"""Returns the partition spec for the model training inputs."""

def _maybe_write_summaries(
self, step: int, new_step: int, train_outputs: StepFnOutput
self,
step: int,
new_step: int,
train_outputs: StepFnOutput,
compute_steps_per_sec_interval_steps: int = 10,
) -> float | None:
# Compute steps/sec every this many steps, revisit when necessary.
compute_steps_per_sec_interval_steps = 10

steps_per_sec = None
should_compute_steps_per_sec = (
new_step % compute_steps_per_sec_interval_steps == 0
Expand Down
3 changes: 3 additions & 0 deletions paxml/tasks_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -1240,6 +1240,8 @@ class Train:
decode_start_after_n_steps: Starts decoder after N steps.
decode_use_ema_states: If True, use ema states to run decode during train,
note that in this case ema MUST be enabled in the learner.
compute_steps_per_sec_interval_steps: number of steps to average over when
computing steps/sec.
profiler_num_steps: The number of steps to be captured by the profiler
based on the step time estimate.
profiler_min_duration_sec: The minimum duration to be captured by the
Expand Down Expand Up @@ -1302,6 +1304,7 @@ class Train:
decode_start_after_n_steps: int = 0
# TODO(zhishuai): verify this for a pjit model.
decode_use_ema_states: bool = False
compute_steps_per_sec_interval_steps: int = 10
profiler_num_steps: int = 2
profiler_min_duration_sec: float = 1.0
profiler_capture_step: int | None = None
Expand Down

0 comments on commit 5297b4f

Please sign in to comment.