Support sharding for auto_trainer (PaddlePaddle#8164)
* add

* fix

* refine code

* refine

* fix

* fix

* fix

* refine
zhangbo9674 authored Apr 1, 2024
1 parent 4d661bc commit 7b493a8
Showing 5 changed files with 41 additions and 55 deletions.
8 changes: 4 additions & 4 deletions llm/llama/auto_parallel/run_pretrain_auto.py
@@ -275,14 +275,14 @@ def create_pretrained_dataset(

train_val_test_num_samples = [
training_args.per_device_train_batch_size
- * training_args.data_parallel_degree
+ * training_args.dataset_world_size
* training_args.max_steps
* training_args.gradient_accumulation_steps,
training_args.per_device_eval_batch_size
- * training_args.data_parallel_degree
+ * training_args.dataset_world_size
* training_args.eval_iters
* (training_args.max_steps // training_args.eval_steps + 1),
- training_args.per_device_eval_batch_size * training_args.data_parallel_degree * training_args.test_iters,
+ training_args.per_device_eval_batch_size * training_args.dataset_world_size * training_args.test_iters,
]

print_rank_0(" > datasets target sizes (minimum size):")
@@ -411,7 +411,7 @@ def init_seed(seed: int = 1234, args=None):
topo = Topology(
dist.get_rank(),
dist.get_world_size(),
- dp_degree=args.data_parallel_degree,
+ dp_degree=args.dataset_world_size,
pp_degree=args.pipeline_parallel_degree,
mp_degree=args.tensor_parallel_degree,
sharding_degree=1, # auto_parallel's sharding is not orthogonal with dp, mp and pp
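With sharding folded into the data axis, the sample counts above must scale with every rank that consumes a distinct shard of data, which is what dataset_world_size captures and plain data_parallel_degree no longer does. A minimal sketch of the sizing arithmetic, using assumed values rather than anything from this commit:

# Assumed configuration, for illustration only.
per_device_train_batch_size = 2
gradient_accumulation_steps = 8
max_steps = 1000

# e.g. 8 GPUs with tp=2, pp=1 and sharding enabled:
# sharding_parallel_degree=4, data_parallel_degree=1.
sharding_parallel_degree = 4
data_parallel_degree = 1
dataset_world_size = sharding_parallel_degree * data_parallel_degree  # 4

train_num_samples = (
    per_device_train_batch_size
    * dataset_world_size  # was data_parallel_degree (1) before this change
    * max_steps
    * gradient_accumulation_steps
)
print(train_num_samples)  # 2 * 4 * 1000 * 8 = 64000

Sizing with data_parallel_degree alone would have requested only 16000 samples here, a quarter of what the sharded ranks actually consume.
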
12 changes: 5 additions & 7 deletions llm/llama/auto_parallel/run_pretrain_auto_static.py
@@ -274,14 +274,14 @@ def create_pretrained_dataset(

train_val_test_num_samples = [
training_args.per_device_train_batch_size
- * training_args.data_parallel_degree
+ * training_args.dataset_world_size
* training_args.max_steps
* training_args.gradient_accumulation_steps,
training_args.per_device_eval_batch_size
- * training_args.data_parallel_degree
+ * training_args.dataset_world_size
* training_args.eval_iters
* (training_args.max_steps // training_args.eval_steps + 1),
- training_args.per_device_eval_batch_size * training_args.data_parallel_degree * training_args.test_iters,
+ training_args.per_device_eval_batch_size * training_args.dataset_world_size * training_args.test_iters,
]

print_rank_0(" > datasets target sizes (minimum size):")
@@ -421,7 +421,7 @@ def init_seed(seed: int = 1234, args=None):
topo = Topology(
dist.get_rank(),
dist.get_world_size(),
- dp_degree=args.data_parallel_degree,
+ dp_degree=args.dataset_world_size,
pp_degree=args.pipeline_parallel_degree,
mp_degree=args.tensor_parallel_degree,
sharding_degree=1, # auto_parallel's sharding is not orthogonal with dp, mp and pp
@@ -600,9 +600,7 @@ def fn(layer):
def loss_func(loss, outputs):
return loss

- total_train_batch_size_per_acc_step = (
- training_args.per_device_train_batch_size * training_args.data_parallel_degree
- )
+ total_train_batch_size_per_acc_step = training_args.per_device_train_batch_size * training_args.dataset_world_size
total_train_batch_size = total_train_batch_size_per_acc_step * training_args.gradient_accumulation_steps

print_config(training_args)
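The static script's global batch size follows the same rule: per-device batch size times every data-consuming rank, times accumulation steps. A short sketch with assumed values only:

# Assumed values, for illustration only.
per_device_train_batch_size = 2
gradient_accumulation_steps = 8
dataset_world_size = 4  # sharding_parallel_degree * data_parallel_degree

total_train_batch_size_per_acc_step = per_device_train_batch_size * dataset_world_size  # 8
total_train_batch_size = total_train_batch_size_per_acc_step * gradient_accumulation_steps  # 64
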
9 changes: 8 additions & 1 deletion paddlenlp/trainer/auto_trainer.py
@@ -32,6 +32,7 @@
from .trainer_callback import TrainerState
from .trainer_utils import ( # set_hyrbid_parallel_seed,
PREFIX_CHECKPOINT_DIR,
+ ShardingOption,
TrainOutput,
_exec_mode_guard,
get_last_checkpoint,
@@ -111,6 +112,13 @@ def _wrap_for_dist_loader(self, train_dataloader):
def _wrap_for_auto(self, model, train_dataloader):
dist_loader = self._wrap_for_dist_loader(train_dataloader)

+ if ShardingOption.SHARD_OP in self.args.sharding:
+ self.optimizer = dist.shard_optimizer(self.optimizer, dist.ShardingStage1())
+ elif ShardingOption.SHARD_GRAD_OP in self.args.sharding:
+ self.optimizer = dist.shard_optimizer(self.optimizer, dist.ShardingStage2())
+ elif ShardingOption.FULL_SHARD in self.args.sharding:
+ self.optimizer = dist.shard_optimizer(self.optimizer, dist.ShardingStage3())

if self.args.to_static:
unified_strategy = dist.Strategy()
unified_strategy._from_legacy_strategy(self.args.strategy)
@@ -119,7 +127,6 @@ def _wrap_for_auto(self, model, train_dataloader):
dist_loader,
)
else:
- self.optimizer = dist.shard_optimizer(self.optimizer)
return model, dist_loader

def _wrap_amp_model(self, args, model):
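_wrap_for_auto now chooses a sharding stage from the trainer's sharding options before any dygraph-to-static conversion, and leaves the optimizer unwrapped when no option is set. A standalone sketch of that dispatch; maybe_shard_optimizer is a hypothetical helper (not part of the trainer), and the stage comments assume the usual stage1/stage2/stage3 meaning of ShardingOption.SHARD_OP / SHARD_GRAD_OP / FULL_SHARD:

import paddle.distributed as dist

from paddlenlp.trainer.trainer_utils import ShardingOption


def maybe_shard_optimizer(optimizer, sharding_options):
    # Wrap the optimizer with the auto-parallel sharding stage implied by --sharding.
    if ShardingOption.SHARD_OP in sharding_options:       # assumed "stage1": shard optimizer states
        return dist.shard_optimizer(optimizer, dist.ShardingStage1())
    if ShardingOption.SHARD_GRAD_OP in sharding_options:  # assumed "stage2": also shard gradients
        return dist.shard_optimizer(optimizer, dist.ShardingStage2())
    if ShardingOption.FULL_SHARD in sharding_options:     # assumed "stage3": also shard parameters
        return dist.shard_optimizer(optimizer, dist.ShardingStage3())
    return optimizer  # no sharding requested: leave the optimizer unwrapped
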
43 changes: 8 additions & 35 deletions paddlenlp/trainer/trainer.py
@@ -40,9 +40,6 @@
import paddle.nn as nn
from packaging import version
from paddle.distributed import fleet
- from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
- DygraphShardingOptimizer,
- )
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.hybrid_parallel_optimizer import (
HybridParallelOptimizer,
)
@@ -1538,38 +1535,14 @@ def apply_decay_param_fun(x):
if hasattr(optimizer_cls, "_create_master_weight") and self.args.fp16_opt_level == "O2":
optimizer_kwargs["multi_precision"] = True

- def is_new_version_sharding_stage1_optimizer():
- signature_keys = set(inspect.signature(DygraphShardingOptimizer).parameters.keys())
- return "inner_optimizer_class" not in signature_keys
-
- if ShardingOption.SHARD_OP in self.args.sharding and not is_new_version_sharding_stage1_optimizer():
- # for backward compatibility.
- # this call will raise, if sharding stage1 is supported in HybridParallelOptimizer,
- # in which case, the logic follows will handle it
- self.optimizer = DygraphShardingOptimizer(
- hcg=fleet.get_hybrid_communicate_group(),
- user_defined_strategy=None,
- params=params,
- inner_optimizer_class=optimizer_cls,
- learning_rate=self.lr_scheduler if lr_scheduler is None else lr_scheduler,
- apply_decay_param_fun=apply_decay_param_fun,
- weight_decay=self.args.weight_decay,
- grad_clip=nn.ClipGradByGlobalNorm(self.args.max_grad_norm)
- if self.args.max_grad_norm > 0
- else None,
- **optimizer_kwargs,
- )
- else:
- self.optimizer = optimizer_cls(
- learning_rate=self.lr_scheduler if lr_scheduler is None else lr_scheduler,
- apply_decay_param_fun=apply_decay_param_fun,
- parameters=params,
- weight_decay=self.args.weight_decay,
- grad_clip=nn.ClipGradByGlobalNorm(self.args.max_grad_norm)
- if self.args.max_grad_norm > 0
- else None,
- **optimizer_kwargs,
- )
+ self.optimizer = optimizer_cls(
+ learning_rate=self.lr_scheduler if lr_scheduler is None else lr_scheduler,
+ apply_decay_param_fun=apply_decay_param_fun,
+ parameters=params,
+ weight_decay=self.args.weight_decay,
+ grad_clip=nn.ClipGradByGlobalNorm(self.args.max_grad_norm) if self.args.max_grad_norm > 0 else None,
+ **optimizer_kwargs,
+ )

return self.optimizer

24 changes: 16 additions & 8 deletions paddlenlp/trainer/training_args.py
@@ -1194,28 +1194,36 @@ def is_segment_parallel_supported():

elif self.enable_auto_parallel:
self.tensor_parallel_degree = max(self.tensor_parallel_degree, 1)
+ self.sep_parallel_degree = max(self.sep_parallel_degree, 1)
self.pipeline_parallel_degree = max(self.pipeline_parallel_degree, 1)

assert (
world_size % (self.tensor_parallel_degree * self.pipeline_parallel_degree) == 0
), f"Total world_size:{world_size} shoule be devided by tensor_parallel_degree: {self.tensor_parallel_degree} and pipeline_parallel_degree: {self.pipeline_parallel_degree}."

- self.data_parallel_degree = world_size // (self.tensor_parallel_degree * self.pipeline_parallel_degree)

if self.sharding_parallel_degree == -1:
if len(self.sharding) > 0:
- self.sharding_parallel_degree = self.data_parallel_degree
+ self.sharding_parallel_degree = world_size // (
+ self.tensor_parallel_degree * self.sep_parallel_degree * self.pipeline_parallel_degree
+ )

self.sharding_parallel_degree = max(self.sharding_parallel_degree, 1)
if self.sharding_parallel_degree == 1 and len(self.sharding) > 0:
logger.warning("sharding_parallel_degree=1 means no sharding, please set sharding to empty!")
self.sharding = []

+ self.data_parallel_degree = world_size // (
+ self.sharding_parallel_degree
+ * self.tensor_parallel_degree
+ * self.sep_parallel_degree
+ * self.pipeline_parallel_degree
+ )

if ShardingOption.OFFLOAD in self.sharding:
warnings.warn("`offload` is not supported NOW!")

strategy = fleet.auto.Strategy()
- if self.data_parallel_degree > 1:
+ if self.dataset_world_size > 1:
data_parallel_config = set(self.data_parallel_config.split(" "))
for x in data_parallel_config:
if len(x) > 0:
@@ -1356,10 +1364,10 @@ def is_segment_parallel_supported():
self.strategy = strategy
if self.hybrid_parallel_topo_order == "pp_first":
order = ["pp", "dp", "mp"]
- degree = [self.pipeline_parallel_degree, self.data_parallel_degree, self.tensor_parallel_degree]
+ degree = [self.pipeline_parallel_degree, self.dataset_world_size, self.tensor_parallel_degree]
elif self.hybrid_parallel_topo_order == "sharding_first":
order = ["dp", "pp", "mp"]
- degree = [self.data_parallel_degree, self.pipeline_parallel_degree, self.tensor_parallel_degree]
+ degree = [self.dataset_world_size, self.pipeline_parallel_degree, self.tensor_parallel_degree]
mesh_dims = list(zip(order, degree))
fleet.auto.create_mesh(mesh_dims)

@@ -1371,7 +1379,7 @@

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
"dp_degree": self.data_parallel_degree,
"dp_degree": self.dataset_world_size,
"mp_degree": self.tensor_parallel_degree,
"pp_degree": self.pipeline_parallel_degree,
"order": order,
@@ -1526,7 +1534,7 @@ def dataset_world_size(self):
if self.use_hybrid_parallel:
return max(self.sharding_parallel_degree, 1) * max(self.data_parallel_degree, 1)
elif self.enable_auto_parallel:
- return max(self.data_parallel_degree, 1)
+ return max(self.sharding_parallel_degree, 1) * max(self.data_parallel_degree, 1)
else:
return paddle.distributed.get_world_size()

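Putting the degree bookkeeping together: sharding_parallel_degree is now derived from the world size and the other parallel degrees, data_parallel_degree is whatever remains once sharding is factored out, and dataset_world_size is their product, which is why the pretraining scripts and mesh construction above switch to it. A worked example under an assumed 8-GPU layout:

# Assumed layout, for illustration only: 8 GPUs, tp=2, pp=2, sep=1, sharding enabled.
world_size = 8
tensor_parallel_degree = 2
pipeline_parallel_degree = 2
sep_parallel_degree = 1

# sharding_parallel_degree == -1 and a non-empty sharding list -> derive it.
sharding_parallel_degree = world_size // (
    tensor_parallel_degree * sep_parallel_degree * pipeline_parallel_degree
)  # 8 // (2 * 1 * 2) = 2

# Data parallelism gets what is left after sharding is taken out.
data_parallel_degree = world_size // (
    sharding_parallel_degree * tensor_parallel_degree * sep_parallel_degree * pipeline_parallel_degree
)  # 8 // (2 * 2 * 1 * 2) = 1

# Both axes read distinct data, so the dataset-facing world size is their product.
dataset_world_size = max(sharding_parallel_degree, 1) * max(data_parallel_degree, 1)  # 2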
