diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py
index e984c05cbc0b..8ebd218447fc 100644
--- a/paddlenlp/trainer/training_args.py
+++ b/paddlenlp/trainer/training_args.py
@@ -253,6 +253,7 @@ class TrainingArguments:
             enable_release_grads, reduce peak memory usage by releasing gradients after each iteration. The creation of gradients will be postponed until backward propagation of the next iteration.
             enable_overlap_p2p_comm, overlap p2p communication with computation.
             enable_clear_every_step_cache, clear every step cache for pipeline parallel.
+            disable_batch_p2p_comm, disable batched send/recv in pipeline parallel mode.
         sharding_parallel_config (`str`, *optional*)(
             Some additional config it highly affect the useage of sharding parallel, we provide some option to config it.
             following config is support:
@@ -616,6 +617,7 @@ class TrainingArguments:
                 "enable_sharding_comm_overlap, fuse sharding stage 1 parallel gradient communication. \n"
                 "enable_overlap_p2p_comm, overlap p2p communication with computation. \n"
                 "enable_clear_every_step_cache, clear every step cache for pipeline parallel. \n"
+                "disable_batch_p2p_comm, disable batched send/recv in pipeline parallel mode. \n"
             )
         },
     )
@@ -993,6 +995,7 @@ def __post_init__(self):
                        "enable_dp_comm_overlap",
                        "enable_clear_every_step_cache",
                        "enable_overlap_p2p_comm",
+                        "disable_batch_p2p_comm",
                    ]:
                        raise ValueError(
                            f"Found unknown pipeline mode config {x}, accpet config is disable_p2p_cache_shape, disable_partial_send_recv."
@@ -1025,6 +1028,7 @@ def __post_init__(self):
                    "release_gradients": "enable_release_grads" in pipeline_parallel_config,
                    "overlap_p2p_comm": "enable_overlap_p2p_comm" in pipeline_parallel_config,
                    "clear_every_step_cache": "enable_clear_every_step_cache" in pipeline_parallel_config,
+                    "use_batch_p2p_comm": "disable_batch_p2p_comm" not in pipeline_parallel_config,
                }
                if dygraph_pp_configs["dp_comm_overlap"]:
                    raise ValueError("overlap has accuracy issue")  # TODO: fix `overalap` + `delay_scale` issue
@@ -1249,6 +1253,7 @@ def is_segment_parallel_supported():
                        # "enable_dp_comm_overlap",  # no implemenation for auto_parallel
                        # "enable_sharding_comm_overlap",  # no implemenation for auto_parallel
                        # "enable_timer",  # no implemenation for auto_parallel
+                        # "disable_batch_p2p_comm",  # no implemenation for auto_parallel
                    ]:
                        raise ValueError(
                            f"Found unknown pipeline mode config {x}, accpet config is enable_send_recv_overlap."
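
Below is a minimal sketch (not part of the diff) of how the new option could be enabled from user code, assuming the existing PaddleNLP `TrainingArguments` API and that `pipeline_parallel_config` is split into individual flags as the `x not in [...]` loop in the diff suggests; the other argument values are illustrative only.

```python
# Sketch only: enabling the new flag through TrainingArguments.
# output_dir and pipeline_parallel_degree are illustrative values,
# not taken from the diff above.
from paddlenlp.trainer import TrainingArguments

args = TrainingArguments(
    output_dir="./checkpoints",
    pipeline_parallel_degree=4,
    # Including "disable_batch_p2p_comm" makes the new
    # "use_batch_p2p_comm" entry of the dygraph pipeline-parallel
    # configs evaluate to False, so the pipeline scheduler falls back
    # to non-batched send/recv instead of batched p2p communication.
    pipeline_parallel_config="enable_overlap_p2p_comm disable_batch_p2p_comm",
)
```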