Commit 244cfc7
support create_checkpoint_symlink (#2975)
Jintao-Huang authored Jan 23, 2025
1 parent ced0654 commit 244cfc7
Showing 6 changed files with 14 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/source/Instruction/命令行参数.md
@@ -295,6 +295,7 @@ Vera uses the three parameters `target_modules`, `target_regex`, and `modules_to_save`.
- add_version: Append an extra `'<version>-<timestamp>'` directory to output_dir to prevent weights from being overwritten, default is True.
- resume_only_model: If resume_from_checkpoint is set, only resume the model weights, default is False.
- check_model: Check local model files for corruption or modification and emit a warning, default is True. If in an offline environment, please set to False.
- create_checkpoint_symlink: Create additional checkpoint symlinks; the best and last model symlinks are f'{output_dir}/best' and f'{output_dir}/last', respectively.
- loss_type: Type of loss, default uses the model's built-in loss function.

- packing: Whether to use packing, default is False.
1 change: 1 addition & 0 deletions docs/source_en/Instruction/Command-line-parameters.md
@@ -299,6 +299,7 @@ Training arguments include the [base arguments](#base-arguments), [Seq2SeqTraine
- add_version: Append an extra `'<version>-<timestamp>'` directory to output_dir to prevent weights from being overwritten, default is True.
- resume_only_model: If resume_from_checkpoint is set, only resume the model weights, default is False.
- check_model: Check local model files for corruption or modification and emit a warning, default is True. If in an offline environment, please set to False.
- create_checkpoint_symlink: Create additional checkpoint symlinks; the best and last model symlinks are f'{output_dir}/best' and f'{output_dir}/last', respectively (see the sketch after this list).
- loss_type: Type of loss, default uses the model's built-in loss function.
- packing: Whether to use packing, default is False.
- 🔥lazy_tokenize: Whether to use lazy_tokenize, default is False during LLM training, default is True during MLLM training.
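
To make the documented layout concrete, the sketch below shows how the two symlinks can be consumed after training; the run directory and checkpoint names are hypothetical examples, not values produced by this commit.

```python
# Minimal sketch; `output_dir` is a hypothetical run directory
# (assuming the default add_version=True naming), not a real path.
import os

output_dir = 'output/v0-20250123-120000'
best = os.path.join(output_dir, 'best')  # symlink -> best checkpoint directory
last = os.path.join(output_dir, 'last')  # symlink -> last (final) checkpoint directory

# Downstream tooling can always point at the same two paths:
print(os.path.realpath(best))  # resolves to something like .../checkpoint-200
print(os.path.realpath(last))  # resolves to something like .../checkpoint-500
```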
1 change: 1 addition & 0 deletions swift/llm/argument/train_args.py
@@ -109,6 +109,7 @@ class TrainArguments(TorchAccArguments, TunerArguments, Seq2SeqTrainingOverrideA
    add_version: bool = True
    resume_only_model: bool = False
    check_model: bool = True
    create_checkpoint_symlink: bool = False

    # dataset
    packing: bool = False
3 changes: 3 additions & 0 deletions swift/llm/dataset/utils.py
@@ -188,6 +188,9 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
                                   'and another piece of data will be randomly selected.')
                    self._traceback_counter += 1

        raise ValueError('Failed to retrieve the dataset. You can avoid this issue by increasing `max_length` or '
                         'modifying the `truncation_strategy`.')

    def __len__(self) -> int:
        return len(self.dataset)
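
For context, the new `ValueError` is the terminal case of a retry loop: when encoding a sample fails, another sample is drawn at random until a retry budget is exhausted. A minimal sketch of that retry-then-raise pattern (names such as `MAX_RETRIES` and `encode_func` are stand-ins, not ms-swift's exact implementation):

```python
import random
from typing import Any, Dict

MAX_RETRIES = 10  # hypothetical retry budget


class LazyDatasetSketch:
    """Sketch only: retry encoding with random fallback samples, then raise."""

    def __init__(self, dataset, encode_func):
        self.dataset = dataset
        self.encode_func = encode_func
        self._traceback_counter = 0

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        for _ in range(MAX_RETRIES):
            try:
                return self.encode_func(self.dataset[idx])
            except Exception:
                # Encoding failed: pick another sample at random and try again.
                idx = random.randrange(len(self.dataset))
                self._traceback_counter += 1
        raise ValueError('Failed to retrieve the dataset. You can avoid this issue by increasing '
                         '`max_length` or modifying the `truncation_strategy`.')

    def __len__(self) -> int:
        return len(self.dataset)
```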

3 changes: 1 addition & 2 deletions swift/llm/export/merge_lora.py
@@ -14,8 +14,7 @@ def merge_lora(args: ExportArguments, device_map=None, replace_if_exists=False)
     output_dir = getattr(args, 'output_dir', None) or f'{args.adapters[0]}-merged'
     if os.path.exists(output_dir) and not replace_if_exists:
         logger.info(f'The weight directory for the merged LoRA already exists in {output_dir}, '
-                    'skipping the saving process. '
-                    'you can pass `replace_if_exists=True` to overwrite it.')
+                    'skipping the saving process.')
     else:
         origin_device_map = args.device_map
         args.device_map = device_map or args.device_map
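
Since the hint about `replace_if_exists` is removed from the log message, a short usage sketch may help; it relies only on the signature shown above, and the adapter path (and any other required `ExportArguments` fields) are assumptions.

```python
from swift.llm import ExportArguments  # assumed public import path
from swift.llm.export.merge_lora import merge_lora

# Hypothetical adapter checkpoint; other required ExportArguments fields are omitted.
args = ExportArguments(adapters=['output/checkpoint-500'])
merge_lora(args, replace_if_exists=True)  # overwrite an existing '<adapter>-merged' directory
```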
7 changes: 7 additions & 0 deletions swift/llm/train/sft.py
@@ -158,6 +158,13 @@ def _save_trainer_state(self, trainer):
        training_args = trainer.args
        state = trainer.state

        if self.args.create_checkpoint_symlink:
            last_checkpoint = os.path.join(self.args.output_dir, 'last')
            best_checkpoint = os.path.join(self.args.output_dir, 'best')
            os.symlink(state.last_model_checkpoint, last_checkpoint)
            os.symlink(state.best_model_checkpoint, best_checkpoint)
            state.last_model_checkpoint = last_checkpoint
            state.best_model_checkpoint = best_checkpoint
        logger.info(f'last_model_checkpoint: {state.last_model_checkpoint}')
        logger.info(f'best_model_checkpoint: {state.best_model_checkpoint}')
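
One caveat on the block above: `os.symlink` raises `FileExistsError` if the link path already exists, so rerunning into an `output_dir` that already holds `best`/`last` links would fail at this point (unlikely with the default `add_version=True`, which creates a fresh directory per run). A defensive variant, offered only as a sketch and not as part of the commit:

```python
import os

def force_symlink(target: str, link_path: str) -> None:
    # Remove a stale symlink first so repeated runs into the same
    # output_dir do not fail with FileExistsError.
    if os.path.islink(link_path):
        os.remove(link_path)
    os.symlink(target, link_path)
```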

