diff --git a/tests/acceleration/test_acceleration_framework.py b/tests/acceleration/test_acceleration_framework.py index 6015cf8cd..d63f717a5 100644 --- a/tests/acceleration/test_acceleration_framework.py +++ b/tests/acceleration/test_acceleration_framework.py @@ -532,61 +532,6 @@ def test_framework_initialize_and_trains_with_aadp(): assert spy["augmentation_calls"] == 1 assert spy["get_ready_for_train_calls"] == 1 - -@pytest.mark.skipif( - not is_fms_accelerate_available(plugins="aadp"), - reason="Only runs if fms-accelerate is installed along with \ - attention_and_distributed_packing plugin", -) -def test_padding_free_plugin_raises_error_with_untokenized_dataset(): - """ - Currently sft_trainer uses DataCollatorForCompletionOnlyLM for unformatted, - untokenized datasets. It uses a DataCollatorForSeq2Seq as default for pretokenized - datasets. - Ensure that padding free plugin will raise an error when an untokenized - dataset is passed to the padding-free plugin when it checks the data collator. - """ - - with tempfile.TemporaryDirectory() as tempdir: - - model_args = copy.deepcopy(MODEL_ARGS) - model_args.model_name_or_path = "TinyLlama/TinyLlama-1.1B-Chat-v0.3" - train_args = copy.deepcopy(TRAIN_ARGS) - train_args.output_dir = tempdir - train_args.save_strategy = "no" - data_args = copy.deepcopy(DATA_ARGS) - data_args.training_data_path = TWITTER_COMPLAINTS_JSON_FORMAT - data_args.response_template = "\n### Response:" - data_args.dataset_text_field = "output" - - # initialize a config - aadp_config = AttentionAndDistributedPackingConfig( - padding_free=PaddingFree(method="huggingface") - ) - - with pytest.raises( - TypeError, - match="The padding-free plugin currently only works with a \ - `DataCollatorForSeq2Seq` collate_fn", - ): - with build_framework_and_maybe_instantiate( - [ - ( - ["training.attention.padding_free"], - PaddingFreeAccelerationPlugin, - ), - ], - instantiate=False, - ): - with instantiate_model_patcher(): - sft_trainer.train( - model_args, - data_args, - train_args, - attention_and_distributed_packing_config=aadp_config, - ) - - def test_error_raised_with_paddingfree_and_flash_attn_disabled(): """Ensure error raised when padding-free is not used with flash attention""" with pytest.raises(