Save batch to disk on OOM. (#343)
* Save batch to disk on OOM.

* minor fixes

* Fixes after review.

* Fix style issues.
csukuangfj authored May 5, 2022
1 parent 9ddbc68 commit e1c3e98
Showing 1 changed file with 60 additions and 22 deletions.
egs/librispeech/ASR/pruned_transducer_stateless2/train.py (60 additions, 22 deletions)
@@ -156,15 +156,16 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate. This value should not need to be changed.",
+        help="The initial learning rate. This value should not need to "
+        "be changed.",
     )

     parser.add_argument(
         "--lr-batches",
         type=float,
         default=5000,
-        help="""Number of steps that affects how rapidly the learning rate decreases.
-        We suggest not to change this.""",
+        help="""Number of steps that affects how rapidly the learning rate
+        decreases. We suggest not to change this.""",
     )

     parser.add_argument(
@@ -670,25 +671,29 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
-            loss, loss_info = compute_loss(
-                params=params,
-                model=model,
-                sp=sp,
-                batch=batch,
-                is_training=True,
-                warmup=(params.batch_idx_train / params.model_warm_step),
-            )
-        # summary stats
-        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
-
-        # NOTE: We use reduction==sum and loss is computed over utterances
-        # in the batch and there is no normalization to it so far.
-        scaler.scale(loss).backward()
-        scheduler.step_batch(params.batch_idx_train)
-        scaler.step(optimizer)
-        scaler.update()
-        optimizer.zero_grad()
+        try:
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, loss_info = compute_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    batch=batch,
+                    is_training=True,
+                    warmup=(params.batch_idx_train / params.model_warm_step),
+                )
+            # summary stats
+            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+            # NOTE: We use reduction==sum and loss is computed over utterances
+            # in the batch and there is no normalization to it so far.
+            scaler.scale(loss).backward()
+            scheduler.step_batch(params.batch_idx_train)
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
+        except:  # noqa
+            display_and_save_batch(batch, params=params, sp=sp)
+            raise

         if params.print_diagnostics and batch_idx == 5:
             return
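
Because the failing batch is written out with torch.save before the exception propagates, it can be reloaded afterwards to inspect exactly what the model was fed. Below is a minimal sketch of such an offline inspection, assuming a dump produced by the new display_and_save_batch helper and the batch layout of lhotse.dataset.K2SpeechRecognitionDataset; the file path and BPE model path are placeholders, not values taken from this commit.

import torch
import sentencepiece as spm

# Placeholder path: the real file is f"{params.exp_dir}/batch-{uuid4()}.pt".
filename = "pruned_transducer_stateless2/exp/batch-<uuid>.pt"

batch = torch.load(filename, map_location="cpu")

features = batch["inputs"]            # padded feature tensor
supervisions = batch["supervisions"]  # dict holding "text" and per-cut metadata

print("features shape:", features.shape)
print("num utterances:", len(supervisions["text"]))

# Re-tokenize the transcripts to count the BPE tokens in the batch,
# mirroring the statistics logged by display_and_save_batch().
sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")  # assumed BPE model path
y = sp.encode(supervisions["text"], out_type=int)
print("num tokens:", sum(len(i) for i in y))
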
@@ -933,6 +938,38 @@ def remove_short_and_long_utt(c: Cut):
         cleanup_dist()


+def display_and_save_batch(
+    batch: dict,
+    params: AttributeDict,
+    sp: spm.SentencePieceProcessor,
+) -> None:
+    """Display the batch statistics and save the batch into disk.
+
+    Args:
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      params:
+        Parameters for training. See :func:`get_params`.
+      sp:
+        The BPE model.
+    """
+    from lhotse.utils import uuid4
+
+    filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+    logging.info(f"Saving batch to {filename}")
+    torch.save(batch, filename)
+
+    supervisions = batch["supervisions"]
+    features = batch["inputs"]
+
+    logging.info(f"features shape: {features.shape}")
+
+    y = sp.encode(supervisions["text"], out_type=int)
+    num_tokens = sum(len(i) for i in y)
+    logging.info(f"num tokens: {num_tokens}")
+
+
 def scan_pessimistic_batches_for_oom(
     model: nn.Module,
     train_dl: torch.utils.data.DataLoader,
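
For reference, display_and_save_batch is not tied to the failure path: it can be called by hand to sanity-check what a dataloader produces, and the uuid4()-based file name means each call writes a fresh batch-*.pt file instead of overwriting an earlier dump. A sketch, assuming params, sp, and a training dataloader train_dl are set up as in the rest of this script:

# Illustrative only: dump the first training batch into params.exp_dir.
batch = next(iter(train_dl))
display_and_save_batch(batch, params=params, sp=sp)
# Logs the destination file, the padded feature shape, and the total
# number of BPE tokens in the batch.
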
@@ -973,6 +1010,7 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                     f"(={crit_values[criterion]}) ..."
                 )
+            display_and_save_batch(batch, params=params, sp=sp)
             raise
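
The dump-then-re-raise pattern now appears both in the training loop and in this sanity-check scan. As a generic illustration of the idea (not the recipe's exact code), it can be factored into a small wrapper around any step function:

import logging
from typing import Any, Callable

import torch


def run_step_or_dump(step: Callable[[dict], Any], batch: dict, dump_path: str) -> Any:
    """Run ``step(batch)``; if it raises, save the batch for offline
    debugging and re-raise so the caller still sees the original error."""
    try:
        return step(batch)
    except Exception:
        logging.info(f"Step failed; saving batch to {dump_path}")
        torch.save(batch, dump_path)
        raise
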


