ParallelLoader with Flexible Tensor Accumulation #8616
Comments
There's some added complexity to how we can stack any arbitrary dataset, but that is implementation-specific.
cc: @tengyifei, @ManfeiBai, do you folks know who can evaluate any risk in the proposal? I'll work on a draft change if none.
I am trying to understand the motivation better. Let's say that there are 4 gradient accumulation steps. What's the issue with just calling `optimizer.step()` after every 4 dataloader iterations? For example:
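A minimal sketch of that pattern (assuming a typical torch_xla training loop; the model, optimizer, and dummy `train_loader` below are illustrative placeholders, not from the original comment):

```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = nn.Linear(4096, 4096).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()
accumulation_steps = 4

# Dummy data standing in for a real (Mp)DeviceLoader that yields device tensors.
train_loader = [(torch.randn(16, 4096, device=device),
                 torch.randn(16, 4096, device=device)) for _ in range(8)]

for step, (data, target) in enumerate(train_loader):
    loss = loss_fn(model(data), target) / accumulation_steps
    loss.backward()                      # gradients accumulate in param.grad
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
    xm.mark_step()                       # placement of the mark step is discussed in the reply below
```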
@bhavya01 The motivation lies in how we batch the data in each iteration. In your example, we're traditionally accumulating the gradients for each training step. Generally, the mark step can either sit within the optimizer branch (in which case we generate a large HLO graph that traces the same train_step 4x), or sit outside it, in which case we cut the graph and decouple any global-step optimizations (at least in a simple manner). We have recently introduced the gradient accumulation API, which allows us to leverage the XLA while loop to accumulate the gradients over N training steps. Similar to scan, this has various benefits by condensing a repetitive set of ops.
If you were to backport this to trace and accumulate the same train_step 4 times, we'd have:
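Roughly the following, reusing the placeholders from the earlier sketch (again an illustrative sketch, not the original snippet):

```python
def train_step(data, target):
    loss = loss_fn(model(data), target) / accumulation_steps
    loss.backward()
    return loss

# The Python loop below re-traces the same train_step body 4 times, so the
# single graph produced at mark_step grows with the number of accumulation steps.
loader_it = iter(train_loader)
for _ in range(accumulation_steps):
    data, target = next(loader_it)
    train_step(data, target)
optimizer.step()
optimizer.zero_grad()
xm.mark_step()   # one large HLO graph containing 4 unrolled copies of train_step
```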
but we want to consolidate the inner loop into an XLA while loop. If we gather the data as `[gradient_accumulation_steps * batch_size, ...]`, as the loader does today, we run into the various issues mentioned above, particularly with SPMD.
🚀 Feature
Currently, when loading device data with multi-processing, we leverage `MpDeviceLoader`, which wraps an existing data loader and takes care of copying the tensors to the device (returning a `per_device_loader`). Users of torch-xla's `MpDeviceLoader` face constraints when they require specific tensor shapes for operations like gradient accumulation: the current implementation only supports accumulating the batch along a single dimension, forcing users to manually reshape the data and handle sharding, which can lead to suboptimal cases (see below), particularly because the data is transferred to the device as is.

Extend ParallelLoader with a flexible accumulation feature that lets users specify the dimension and size along which incoming batches are stacked before they are transferred to the device.
For instance, given `accumulation_dim=0`, `accumulation_size=16` and `batch_dim=1`, with the underlying train loader returning batches of 32 at a time, we would expect to get `[16, 32, seq_dim]`. Similarly, `accumulation_dim=1` with `accumulation_size=16` would return `[32, 16, seq_dim]`, since `batch_dim` currently defaults to 0. If `batch_dim` and `accumulation_dim` are equal, we throw an error.

Hence, the customer can specify a train loader with the intended batch size (not counting the accumulation), while explicitly specifying the accumulation dimension and size on which the batches are stacked. A non-breaking requirement is that customers need to specify `batch_dim` if they use `accumulation_size > 0`, since it currently defaults to `0`.
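To make the intended stacking semantics concrete, here is a plain-PyTorch illustration of the resulting shapes (a sketch only, not the proposed implementation):

```python
import torch

accumulation_size = 16
batch_size, seq_dim = 32, 4096

# The wrapped train loader yields `accumulation_size` batches of [batch_size, seq_dim].
batches = [torch.randn(batch_size, seq_dim) for _ in range(accumulation_size)]

# accumulation_dim=0 with batch_dim=1  ->  torch.Size([16, 32, 4096])
print(torch.stack(batches, dim=0).shape)
# accumulation_dim=1 with batch_dim=0 (the current default)  ->  torch.Size([32, 16, 4096])
print(torch.stack(batches, dim=1).shape)
```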
Motivation
The current data loader is tailored to pull in a certain number of batches along the batch dimension (defaulted to 0). However, in some cases the training may require a different or specific shape, such as with gradient accumulation (e.g. `[4, 4, 4096]` instead of `[16, 4096]` for batch size = 16, with, for instance, DP = 4 and 4 gradient accumulation steps). The problem with the current setup is that it constrains the data to be accumulated along a single dimension, and that flat tensor is what later gets sent to the device.
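For illustration, today the user receives the flat batch and has to reshape it (and, under SPMD, handle its sharding) by hand; a sketch using the shapes from the example above:

```python
import torch

# What the loader yields today: one flat batch along dim 0.
flat = torch.randn(16, 4096)          # batch_size = 16

# What a while-loop style gradient accumulation step wants instead.
stacked = flat.view(4, 4, 4096)       # [accumulation_steps, per_step_batch, hidden]
print(stacked.shape)
# Under SPMD, the sharding of `stacked` also has to be re-annotated manually.
```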
Pitch
Instead, this proposes doing:
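A sketch of the kind of usage being proposed; the `accumulation_dim`, `accumulation_size`, and `batch_dim` keyword arguments are the proposal itself, not an existing `MpDeviceLoader` API, and the surrounding setup is illustrative:

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
from torch.utils.data import DataLoader, TensorDataset

device = xm.xla_device()
dataset = TensorDataset(torch.randn(1024, 4096), torch.randn(1024, 4096))
train_loader = DataLoader(dataset, batch_size=32)   # per-step batch size only

# Proposed behaviour: the wrapper stacks `accumulation_size` host-side batches
# along `accumulation_dim` and performs a single transfer to the device.
device_loader = pl.MpDeviceLoader(
    train_loader,
    device,
    accumulation_dim=0,    # proposed kwarg
    accumulation_size=16,  # proposed kwarg
    batch_dim=1,           # must be given explicitly when accumulation_size > 0
)

for data, target in device_loader:
    # data: [16, 32, 4096], ready for a while-loop / scan-style gradient
    # accumulation step without any manual reshaping or re-sharding.
    ...
```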
Alternatives
Additional context
TODO - add real-life example