diff --git a/.github/actions/setup/action.yaml b/.github/actions/setup/action.yaml index 488a190..65657cf 100644 --- a/.github/actions/setup/action.yaml +++ b/.github/actions/setup/action.yaml @@ -8,4 +8,4 @@ runs: - name: "Install libsndfile" shell: bash run: | - sudo apt-get install libsndfile1 python3-distutils \ No newline at end of file + sudo apt-get install libsndfile1 \ No newline at end of file diff --git a/lcm/datasets/base.py b/lcm/datasets/base.py index 4c2c50e..14054f2 100644 --- a/lcm/datasets/base.py +++ b/lcm/datasets/base.py @@ -53,9 +53,9 @@ def pipeline(self) -> DataPipeline: self._pipeline = self.builder_func( self.datasets, self.data_config, gang_rank, world_size ) - assert ( - self._pipeline - ), f"Cannot build data pipeline from config {self.data_config}" + assert self._pipeline, ( + f"Cannot build data pipeline from config {self.data_config}" + ) return self._pipeline def destroy(self) -> None: diff --git a/lcm/datasets/batch.py b/lcm/datasets/batch.py index 0147bf5..eca937f 100644 --- a/lcm/datasets/batch.py +++ b/lcm/datasets/batch.py @@ -249,9 +249,9 @@ def __post_init__(self): length = len(self.source) - assert ( - (self.target is None) or (len(self.target) == length) - ), f"all elements in LCMInput should be of the same length, got {len(self.target)} and {length}" + assert (self.target is None) or (len(self.target) == length), ( + f"all elements in LCMInput should be of the same length, got {len(self.target)} and {length}" + ) def __len__(self) -> int: return len(self.source) @@ -296,9 +296,9 @@ def prepare_input( ) elif style == LCMStyle.SUPERVISED: - assert ( - self.target is not None - ), "Missing target embeddings for a supervised batch" + assert self.target is not None, ( + "Missing target embeddings for a supervised batch" + ) return get_embeddings_sequence( src_seqs=self.source, tgt_seqs=self.target, diff --git a/lcm/datasets/dataloader.py b/lcm/datasets/dataloader.py index 1191e1d..8922045 100644 --- a/lcm/datasets/dataloader.py +++ b/lcm/datasets/dataloader.py @@ -195,9 +195,9 @@ def _tokenize_batch(self, batch: Dict[str, Any]) -> LCMInput: else: embs = None outputs[key] = embs - assert ( - outputs["source"] is not None - ), "LCMDataLoader requires `source` sequences to be present in batches" + assert outputs["source"] is not None, ( + "LCMDataLoader requires `source` sequences to be present in batches" + ) return LCMInput(**outputs) def iterate_batches(self) -> Iterator[LCMInput]: diff --git a/lcm/datasets/dataloading.py b/lcm/datasets/dataloading.py index 3b48e5c..b14b6d5 100644 --- a/lcm/datasets/dataloading.py +++ b/lcm/datasets/dataloading.py @@ -202,9 +202,9 @@ def build_dataload_pipeline( self, rank: int = 0, world_size: int = 1 ) -> DataPipelineBuilder: if world_size > 1: - assert ( - self.loading_config.seed is not None - ), "for distributed training with `world_size` > 1, `seed` should be set !" + assert self.loading_config.seed is not None, ( + "for distributed training with `world_size` > 1, `seed` should be set !" + ) if self.is_validation: self.set_validation_params(world_size) @@ -321,12 +321,12 @@ def create_on_the_fly_columns( self, pipeline: DataPipelineBuilder ) -> DataPipelineBuilder: if self.dataset_config.source_sequences is not None: - assert ( - self.dataset_config.source_column is not None - ), f"Expected a source_column - found {self.dataset_config.source_column}" - assert ( - self.dataset_config.source_text_column is not None - ), f"Expected a source_text_column - found {self.dataset_config.source_text_column}" + assert self.dataset_config.source_column is not None, ( + f"Expected a source_column - found {self.dataset_config.source_column}" + ) + assert self.dataset_config.source_text_column is not None, ( + f"Expected a source_text_column - found {self.dataset_config.source_text_column}" + ) pipeline = pipeline.map( partial( @@ -338,12 +338,12 @@ def create_on_the_fly_columns( num_parallel_calls=self._num_parallel_call(self.nb_parallel_fragments), ) if self.dataset_config.target_sequences is not None: - assert ( - self.dataset_config.target_column is not None - ), f"Expected a target_column, found {self.dataset_config.target_column}" - assert ( - self.dataset_config.target_text_column is not None - ), f"Expected a target_text_columns, found {self.dataset_config.target_text_column}" + assert self.dataset_config.target_column is not None, ( + f"Expected a target_column, found {self.dataset_config.target_column}" + ) + assert self.dataset_config.target_text_column is not None, ( + f"Expected a target_text_columns, found {self.dataset_config.target_text_column}" + ) pipeline = pipeline.map( partial( @@ -426,9 +426,9 @@ def config_post_init(self) -> None: ) if self.loading_config.even_sharding: - assert ( - self.loading_config.seed is not None - ), "`even_sharding` sharding requires to seed to be set" + assert self.loading_config.seed is not None, ( + "`even_sharding` sharding requires to seed to be set" + ) if self.loading_config.max_tokens == 0: self.loading_config.max_tokens = None @@ -876,9 +876,9 @@ def add_min_max_sentence_len_in_doc_filter( self.loading_config.max_sentence_len_in_doc or self.loading_config.min_sentence_len_in_doc ): - assert ( - self.dataset_config.source_text_column is not None - ), f"Expexted a source_text_columns, found {self.dataset_config.source_text_column}" + assert self.dataset_config.source_text_column is not None, ( + f"Expexted a source_text_columns, found {self.dataset_config.source_text_column}" + ) pipeline = pipeline.map( partial( @@ -962,9 +962,9 @@ def add_quality_score_filters( if source_quality_range is None: return pipeline - assert ( - self.dataset_config.source_quality_column is not None - ), f"Expected a source_quality_columns, found {self.dataset_config.source_quality_column}" + assert self.dataset_config.source_quality_column is not None, ( + f"Expected a source_quality_columns, found {self.dataset_config.source_quality_column}" + ) pipeline = pipeline.map( partial( diff --git a/lcm/datasets/parquet_utils.py b/lcm/datasets/parquet_utils.py index 12c4396..61eb5f2 100644 --- a/lcm/datasets/parquet_utils.py +++ b/lcm/datasets/parquet_utils.py @@ -486,9 +486,9 @@ def build_batching_loop_over_one_table( num_parallel_calls: int = 1, ) -> DataPipeline: if max_tokens is not None: - assert ( - length_column is not None - ), "Need to provide a column to compute the number of tokens" + assert length_column is not None, ( + "Need to provide a column to compute the number of tokens" + ) random_state = np.random.RandomState(seed) if length_column is not None and len(length_column) > 0: @@ -1109,9 +1109,9 @@ def get_row_group_level_metadata( columns_to_exclude = set(["row_group_id", "num_rows", "total_byte_size"]) & set( columns ) - assert ( - len(columns_to_exclude) == 0 - ), f"names conflict, rename/remove : {columns_to_exclude}" + assert len(columns_to_exclude) == 0, ( + f"names conflict, rename/remove : {columns_to_exclude}" + ) def get_one_row_group_stats(row_group): metadata = row_group.metadata diff --git a/lcm/evaluation/arun.py b/lcm/evaluation/arun.py index 527999e..4a047f2 100644 --- a/lcm/evaluation/arun.py +++ b/lcm/evaluation/arun.py @@ -114,12 +114,12 @@ def run(self, iteration_value: Optional[Any] = None, iteration_index: int = 0): ) if iteration_value is not None: - assert ( - isinstance(iteration_value, int) and self.config.nshards - ), f"Invalid shard value ({self.config.nshards}) or iteration value ({iteration_value})" - assert ( - self.config.data_loading - ), f"Data loading is not specified: \n {self.config}" + assert isinstance(iteration_value, int) and self.config.nshards, ( + f"Invalid shard value ({self.config.nshards}) or iteration value ({iteration_value})" + ) + assert self.config.data_loading, ( + f"Data loading is not specified: \n {self.config}" + ) self.config.data_loading.rank = iteration_value self.config.data_loading.world_size = int(self.config.nshards) @@ -194,9 +194,9 @@ async def schedule_task( result = (metrics, result_file) result_metrics, result_file = result - assert isinstance( - result_metrics, dict - ), f"Expected Tuple[Dict[str, AverageMetrics], str], get {type(result_metrics)}" + assert isinstance(result_metrics, dict), ( + f"Expected Tuple[Dict[str, AverageMetrics], str], get {type(result_metrics)}" + ) metrics = {} cf = getattr(module.config, "confidence_level", None) diff --git a/lcm/evaluation/cli/configs.py b/lcm/evaluation/cli/configs.py index aa98a46..150a618 100644 --- a/lcm/evaluation/cli/configs.py +++ b/lcm/evaluation/cli/configs.py @@ -78,9 +78,9 @@ class CliConfig: def __post_init__(self) -> None: self.metric_log_dir = self.metric_log_dir or self.dump_dir - assert ( - self.temperature >= 0.0 - ), f"Expect non-zero temperature, get {self.temperature}" + assert self.temperature >= 0.0, ( + f"Expect non-zero temperature, get {self.temperature}" + ) if self.temperature == 0: self.top_p = 0 self.top_k = 0 diff --git a/lcm/evaluation/metrics/__init__.py b/lcm/evaluation/metrics/__init__.py index ec5fe4f..6db56fd 100644 --- a/lcm/evaluation/metrics/__init__.py +++ b/lcm/evaluation/metrics/__init__.py @@ -57,9 +57,9 @@ def get_scorer( if "outputs" in defaults: output_columns = defaults["outputs"].default else: - assert ( - config.model_name - ), f"Cannot resolve output name for the scorer type {scorer_cls}" + assert config.model_name, ( + f"Cannot resolve output name for the scorer type {scorer_cls}" + ) output_columns = scorer_cls.default_outputs(config.model_name) if isinstance(output_columns, str): diff --git a/lcm/evaluation/metrics/multilingual_similarity.py b/lcm/evaluation/metrics/multilingual_similarity.py index 9d0efd9..53291ed 100644 --- a/lcm/evaluation/metrics/multilingual_similarity.py +++ b/lcm/evaluation/metrics/multilingual_similarity.py @@ -68,9 +68,9 @@ def translate( ) -> List[str]: src_lang, tgt_lang = self.src_lang, self.tgt_lang sent_translations = [] - assert isinstance( - self.model, EncoderDecoderModel - ), f"Unsupported type: {type(self.model)}" + assert isinstance(self.model, EncoderDecoderModel), ( + f"Unsupported type: {type(self.model)}" + ) generator = BeamSearchSeq2SeqGenerator( self.model, echo_prompt=True, max_seq_len=self.max_seq_len ) diff --git a/lcm/evaluation/metrics/sentence_fluency.py b/lcm/evaluation/metrics/sentence_fluency.py index 8b72bc3..c897c4f 100644 --- a/lcm/evaluation/metrics/sentence_fluency.py +++ b/lcm/evaluation/metrics/sentence_fluency.py @@ -200,9 +200,9 @@ def score_texts( bos_token = bos_token or getattr(self.tokenizer, "bos_token", "\n") if eos_token != "": eos_token = eos_token or getattr(self.tokenizer, "eos_token", "\n") - assert ( - eos_token is not None and bos_token is not None - ), "Expecting eos and bos tokens, for perplexity without any surrounding tokens, use eos_token='' and bos_token=''" + assert eos_token is not None and bos_token is not None, ( + "Expecting eos and bos tokens, for perplexity without any surrounding tokens, use eos_token='' and bos_token=''" + ) logger.info( f"Computing perplexity with bos_token={repr(bos_token)} and eos_token={repr(eos_token)}" ) @@ -340,9 +340,9 @@ def backtranslate( translations = [] back_translations = [] losses = [] - assert isinstance( - self.model, EncoderDecoderModel - ), f"Unsupported type: {type(self.model)}" + assert isinstance(self.model, EncoderDecoderModel), ( + f"Unsupported type: {type(self.model)}" + ) generator = BeamSearchSeq2SeqGenerator( self.model, echo_prompt=True, max_seq_len=self.max_seq_len ) diff --git a/lcm/evaluation/predictors/__init__.py b/lcm/evaluation/predictors/__init__.py index 398adf9..c9e0ffe 100644 --- a/lcm/evaluation/predictors/__init__.py +++ b/lcm/evaluation/predictors/__init__.py @@ -42,9 +42,9 @@ def build_predictor( if isinstance(predictor_config, PredictorConfig): config_cls = predictor_config.__class__ else: - assert ( - predictor_type is not None - ), f"Cannot infer predictor from config type {type(predictor_config)}" + assert predictor_type is not None, ( + f"Cannot infer predictor from config type {type(predictor_config)}" + ) config_cls = get_config_cls(predictor_type) predictor_config = promote_config(predictor_config, config_cls) diff --git a/lcm/evaluation/predictors/lcm.py b/lcm/evaluation/predictors/lcm.py index 15aa8a0..9d27a17 100644 --- a/lcm/evaluation/predictors/lcm.py +++ b/lcm/evaluation/predictors/lcm.py @@ -209,9 +209,9 @@ def __call__( if max_gen_len is None: max_gen_len = self.config.max_seq_len - assert isinstance( - prompts, Batched - ), f"Expect sequence of prompts, get {type(prompts)}" + assert isinstance(prompts, Batched), ( + f"Expect sequence of prompts, get {type(prompts)}" + ) # Extract the input embeddings seqs: List[torch.Tensor] = [] @@ -226,9 +226,9 @@ def __call__( ) seqs.append(prompt_embs) else: - assert ( - isinstance(prompts, list) and isinstance(prompts[0], torch.Tensor) - ), f"Expect sonarized prompts in the form or List[torch.Tensor], get {type(prompts)}" + assert isinstance(prompts, list) and isinstance(prompts[0], torch.Tensor), ( + f"Expect sonarized prompts in the form or List[torch.Tensor], get {type(prompts)}" + ) seqs = prompts if max_prompt_len: diff --git a/lcm/evaluation/run.py b/lcm/evaluation/run.py index 099723f..c340c41 100644 --- a/lcm/evaluation/run.py +++ b/lcm/evaluation/run.py @@ -190,9 +190,9 @@ def run_task( task_config = TaskRegistry.get_config(run_config.task_name, **run_config.params) data_loader_type = TaskRegistry.get_dataloader_type(run_config.task_name) - assert isinstance( - run_config.data_loading, EvaluationDataLoadingConfig - ), "data loading not specified" + assert isinstance(run_config.data_loading, EvaluationDataLoadingConfig), ( + "data loading not specified" + ) task = build_task( task_config, data_loading_config=run_config.data_loading, diff --git a/lcm/evaluation/tasks/base.py b/lcm/evaluation/tasks/base.py index 56427d0..4e51519 100644 --- a/lcm/evaluation/tasks/base.py +++ b/lcm/evaluation/tasks/base.py @@ -103,13 +103,13 @@ def load_few_shots( few_shot_examples = config.few_shot_examples if config.num_few_shot > 0: if few_shot_examples is None: - assert ( - config.few_shot_file - ), f"Expect non-empty few_shot_file when few_shot = {config.num_few_shot}" + assert config.few_shot_file, ( + f"Expect non-empty few_shot_file when few_shot = {config.num_few_shot}" + ) assert data_loader, "Expect non-empty data loader" - assert isinstance( - data_loader.data_config, EvaluationDataLoadingConfig - ), f"unexpected data loading type: {type(data_loader.data_config)}" + assert isinstance(data_loader.data_config, EvaluationDataLoadingConfig), ( + f"unexpected data loading type: {type(data_loader.data_config)}" + ) few_shot_data_loader = JSONTestDataLoader( data_config=data_loader.data_config, dataset=JSONDatasetConfig(file_path=config.few_shot_file), @@ -487,9 +487,9 @@ def run( # type: ignore[override] x["nlls"] = defaultdict(list) for cix, (text, pred) in enumerate(zip(x["choice_texts"], preds)): assert pred.tokens and pred.logprobs and pred.text_offsets, pred - assert isinstance( - pred.text, str - ), "multiple texts output is not supported in LLM predictor" + assert isinstance(pred.text, str), ( + "multiple texts output is not supported in LLM predictor" + ) logprobs = pred.logprobs[ text_index(pred.text, pred.text_offsets, text) ] @@ -501,9 +501,9 @@ def run( # type: ignore[override] assert len(preds) == 2 * len(x["choice_texts"]) compl = preds[cix + len(x["choice_texts"])] assert compl.tokens and compl.logprobs and compl.text_offsets - assert isinstance( - compl.text, str - ), "multiple texts output is not supported in LLM predictor" + assert isinstance(compl.text, str), ( + "multiple texts output is not supported in LLM predictor" + ) slice = text_index(compl.text, compl.text_offsets, text) nll_compl = -sum(logprobs) + sum(compl.logprobs[slice]) x["nlls"]["completion"].append(nll_compl) diff --git a/lcm/evaluation/utils/common.py b/lcm/evaluation/utils/common.py index aa8d2b7..ce96d43 100644 --- a/lcm/evaluation/utils/common.py +++ b/lcm/evaluation/utils/common.py @@ -213,8 +213,7 @@ def string_format(template: str, skip_validation: bool = True, **kwargs: Any) -> variables = [k[1] for k in string.Formatter().parse(template) if k[1]] if not all(k in kwargs for k in variables): raise ValueError( - f"Expected: {variables}, got: {sorted(kwargs)}.\n" - f"Template:\n{template}" + f"Expected: {variables}, got: {sorted(kwargs)}.\nTemplate:\n{template}" ) # `Dict[Optional[str], typing.Any]`. kwargs = {k: kwargs[k] for k in variables} @@ -580,9 +579,9 @@ def flatten_dict(data: Dict[str, Any], prefix: str = "") -> Dict[str, Any]: if isinstance(value, dict): flat_dict.update(flatten_dict(value, f"{prefix}{key}/")) elif isinstance(value, Tensor): - assert ( - len(value.size()) == 0 - ), f"Only scalar tensor can be formatted, get {value}" + assert len(value.size()) == 0, ( + f"Only scalar tensor can be formatted, get {value}" + ) flat_dict[f"{prefix}{key}"] = value.item() elif isinstance(value, AverageMetric): flat_dict[f"{prefix}{key}"] = value.value diff --git a/lcm/evaluation/utils/data_utils.py b/lcm/evaluation/utils/data_utils.py index 1e17e6e..418e852 100644 --- a/lcm/evaluation/utils/data_utils.py +++ b/lcm/evaluation/utils/data_utils.py @@ -625,9 +625,9 @@ def default_lcm_postprocess( ) -> Example: # Get the best hypothesis prediction_text = x[PREDICTION_TEXT_COLUMN][0] - assert isinstance( - prediction_text, list - ), f"LCM prediction texts are list of sentences, got {type(prediction_text)}" + assert isinstance(prediction_text, list), ( + f"LCM prediction texts are list of sentences, got {type(prediction_text)}" + ) preds = prediction_text diff --git a/lcm/evaluation/utils/segment_alignment.py b/lcm/evaluation/utils/segment_alignment.py index cca97e3..c216f9a 100644 --- a/lcm/evaluation/utils/segment_alignment.py +++ b/lcm/evaluation/utils/segment_alignment.py @@ -262,9 +262,9 @@ def align_outputs_with_paragraphs( output_columns = [ c for c in output_df.columns if c.startswith(output_column_prefix) ] - assert ( - len(output_columns) > 0 - ), f"No output columns starting with the output_column_prefix='{output_column_prefix}' found." + assert len(output_columns) > 0, ( + f"No output columns starting with the output_column_prefix='{output_column_prefix}' found." + ) # 1.2. Read the human alignments paragraphs_ids = [] @@ -283,13 +283,13 @@ def align_outputs_with_paragraphs( ] # 1.3. The order of the documents might be different. Align it! - assert ( - len(machine_inputs) == len(paragraphs_texts) - ), f"Machine outputs contain {len(machine_inputs)} documents, but human paragraphs contain {len(paragraphs_texts)} documents." + assert len(machine_inputs) == len(paragraphs_texts), ( + f"Machine outputs contain {len(machine_inputs)} documents, but human paragraphs contain {len(paragraphs_texts)} documents." + ) machine_ids = match_text_pairs(paragraphs_texts, machine_inputs) - assert len(machine_ids) == len( - set(machine_ids) - ), "Machine and human outputs don't seem to be 1:1 alignable." + assert len(machine_ids) == len(set(machine_ids)), ( + "Machine and human outputs don't seem to be 1:1 alignable." + ) named_machine_outputs: Dict[str, List[str]] = { column: output_df[column].loc[machine_ids].tolist() for column in output_columns diff --git a/lcm/inference/lcm/generator.py b/lcm/inference/lcm/generator.py index bac3f9c..3e81b90 100644 --- a/lcm/inference/lcm/generator.py +++ b/lcm/inference/lcm/generator.py @@ -90,9 +90,9 @@ def __init__( self.max_seq_len = options.max_seq_len self.min_seq_len = options.min_seq_len - assert ( - self.min_seq_len >= 1 - ), f"min_seq_len must be greater than or equal to 1, min_seq_len={options.min_seq_len}" + assert self.min_seq_len >= 1, ( + f"min_seq_len must be greater than or equal to 1, min_seq_len={options.min_seq_len}" + ) self.eos_threshold = options.eos_threshold @@ -145,9 +145,9 @@ def __call__( self.prompt_seq_lens = None else: self.prompt_seq_lens = prompt_padding_mask.seq_lens - assert ( - self.prompt_seq_lens is not None - ), "Expecting a valid `self.prompt_seq_lens` Tensor, found `None`" + assert self.prompt_seq_lens is not None, ( + "Expecting a valid `self.prompt_seq_lens` Tensor, found `None`" + ) self.min_prompt_len = int(torch.min(self.prompt_seq_lens, dim=0)[0].item()) # Keep the materialized mask @@ -158,17 +158,17 @@ def __call__( # Make sure we do not accidentally set a max_gen_len that exceeds # the generator's model capability - assert ( - max_gen_len <= self.max_seq_len - ), f"Generator can generate up to {self.max_seq_len} sequences, max_gen_len={max_gen_len}" + assert max_gen_len <= self.max_seq_len, ( + f"Generator can generate up to {self.max_seq_len} sequences, max_gen_len={max_gen_len}" + ) self.max_gen_len = max_gen_len if not min_gen_len: min_gen_len = self.min_seq_len - assert ( - min_gen_len > 0 - ), f"min_gen_len must be greater than or equal to 1, min_gen_len={min_gen_len}" + assert min_gen_len > 0, ( + f"min_gen_len must be greater than or equal to 1, min_gen_len={min_gen_len}" + ) self.min_gen_len = min_gen_len if temperature == 0.0: @@ -303,9 +303,9 @@ def finalize_step( # Ignore prompt positions between min-max prompt_len must_keep_going = None if self.step_nr < self.max_prompt_len: - assert ( - self.prompt_padding_mask is not None - ), f"If self.prompt_padding_mas is None, then self.step_nr should start from self.max_prompt_len={self.max_prompt_len} - currently self.step_nr = {self.step_nr}" + assert self.prompt_padding_mask is not None, ( + f"If self.prompt_padding_mas is None, then self.step_nr should start from self.max_prompt_len={self.max_prompt_len} - currently self.step_nr = {self.step_nr}" + ) mask = self.prompt_padding_mask[:, self.step_nr] model_last_output[mask] = self.seqs[mask, self.step_nr] must_keep_going = mask diff --git a/lcm/inference/lcm/scorer.py b/lcm/inference/lcm/scorer.py index e7669bf..15cbd15 100644 --- a/lcm/inference/lcm/scorer.py +++ b/lcm/inference/lcm/scorer.py @@ -62,9 +62,9 @@ def __call__( # type: ignore ) else: self.text_seq_lens = text_padding_mask.seq_lens - assert ( - self.text_seq_lens is not None - ), "Expecting a valid `self.text_seq_lens` Tensor, found `None`" + assert self.text_seq_lens is not None, ( + "Expecting a valid `self.text_seq_lens` Tensor, found `None`" + ) # Keep the materialized mask self.text_padding_mask = text_padding_mask.materialize() @@ -77,17 +77,17 @@ def __call__( # type: ignore # Make sure we do not accidentally set a max_gen_len that exceeds # the generator's model capability - assert ( - max_gen_len <= self.max_seq_len - ), f"Generator can generate up to {self.max_seq_len} sequences, max_gen_len={max_gen_len}" + assert max_gen_len <= self.max_seq_len, ( + f"Generator can generate up to {self.max_seq_len} sequences, max_gen_len={max_gen_len}" + ) self.max_gen_len = max_gen_len if not min_gen_len: min_gen_len = self.min_seq_len - assert ( - min_gen_len > 0 - ), f"min_gen_len must be greater than or equal to 1, min_gen_len={min_gen_len}" + assert min_gen_len > 0, ( + f"min_gen_len must be greater than or equal to 1, min_gen_len={min_gen_len}" + ) self.min_gen_len = min_gen_len if temperature == 0.0: diff --git a/lcm/inference/two_tower_diffusion_lcm/generator.py b/lcm/inference/two_tower_diffusion_lcm/generator.py index 6051ef2..62487d8 100644 --- a/lcm/inference/two_tower_diffusion_lcm/generator.py +++ b/lcm/inference/two_tower_diffusion_lcm/generator.py @@ -86,9 +86,9 @@ def __init__( ) -> None: super().__init__(model, options, eos_vec) - assert isinstance( - self.model, TwoTowerDiffusionLCModel - ), "The TwoTowerDiffusionLCMGenerator expects a Diffusion LCM" + assert isinstance(self.model, TwoTowerDiffusionLCModel), ( + "The TwoTowerDiffusionLCMGenerator expects a Diffusion LCM" + ) logger.info( f"Setting up the model with decoding_options: {options} -- {type(options)}" @@ -133,9 +133,9 @@ def __call__( self.prompt_seq_lens = None else: self.prompt_seq_lens = prompt_padding_mask.seq_lens - assert ( - self.prompt_seq_lens is not None - ), "Expecting a valid `self.prompt_seq_lens` Tensor, found `None`" + assert self.prompt_seq_lens is not None, ( + "Expecting a valid `self.prompt_seq_lens` Tensor, found `None`" + ) self.min_prompt_len = int(torch.min(self.prompt_seq_lens, dim=0)[0].item()) # Keep the materialized mask @@ -146,17 +146,17 @@ def __call__( # Make sure we do not accidentally set a max_gen_len that exceeds # the generator's model capability - assert ( - max_gen_len <= self.max_seq_len - ), f"Generator can generate up to {self.max_seq_len} sequences, max_gen_len={max_gen_len}" + assert max_gen_len <= self.max_seq_len, ( + f"Generator can generate up to {self.max_seq_len} sequences, max_gen_len={max_gen_len}" + ) self.max_gen_len = max_gen_len if not min_gen_len: min_gen_len = self.min_seq_len - assert ( - min_gen_len > 0 - ), f"min_gen_len must be greater than or equal to 1, min_gen_len={min_gen_len}" + assert min_gen_len > 0, ( + f"min_gen_len must be greater than or equal to 1, min_gen_len={min_gen_len}" + ) self.min_gen_len = min_gen_len if temperature == 0.0: @@ -223,9 +223,9 @@ def state_bag_reorder(self, new_order: torch.Tensor) -> None: def prefill(self, **kwargs) -> None: """encode the prefix with the context encoder""" - assert ( - self.context_state_bag is not None - ), "Expecting a context state bag to prefill" + assert self.context_state_bag is not None, ( + "Expecting a context state bag to prefill" + ) context: EmbeddingsBatch @@ -320,9 +320,9 @@ def finalize_step( # Ignore prompt positions between min-max prompt_len must_keep_going = None if self.step_nr < self.max_prompt_len: - assert ( - self.prompt_padding_mask is not None - ), f"If self.prompt_padding_mas is None, then self.step_nr should start from self.max_prompt_len={self.max_prompt_len} - currently self.step_nr = {self.step_nr}" + assert self.prompt_padding_mask is not None, ( + f"If self.prompt_padding_mas is None, then self.step_nr should start from self.max_prompt_len={self.max_prompt_len} - currently self.step_nr = {self.step_nr}" + ) mask = self.prompt_padding_mask[:, self.step_nr] model_last_output[mask] = self.seqs[mask, self.step_nr] must_keep_going = mask diff --git a/lcm/inference/two_tower_diffusion_lcm/scorer.py b/lcm/inference/two_tower_diffusion_lcm/scorer.py index 703e48b..bad99d7 100644 --- a/lcm/inference/two_tower_diffusion_lcm/scorer.py +++ b/lcm/inference/two_tower_diffusion_lcm/scorer.py @@ -75,9 +75,9 @@ def __call__( # type: ignore ) else: self.text_seq_lens = text_padding_mask.seq_lens - assert ( - self.text_seq_lens is not None - ), "Expecting a valid `self.text_seq_lens` Tensor, found `None`" + assert self.text_seq_lens is not None, ( + "Expecting a valid `self.text_seq_lens` Tensor, found `None`" + ) # Keep the materialized mask self.text_padding_mask = text_padding_mask.materialize() @@ -90,9 +90,9 @@ def __call__( # type: ignore # Make sure we do not accidentally set a max_gen_len that exceeds # the generator's model capability - assert ( - max_gen_len <= self.max_seq_len - ), f"Generator can generate up to {self.max_seq_len} sequences, max_gen_len={max_gen_len}" + assert max_gen_len <= self.max_seq_len, ( + f"Generator can generate up to {self.max_seq_len} sequences, max_gen_len={max_gen_len}" + ) self.max_gen_len = max_gen_len if not min_gen_len: @@ -100,9 +100,9 @@ def __call__( # type: ignore assert min_gen_len is not None, "A `min_gen_len` is required" - assert ( - min_gen_len > 0 - ), f"min_gen_len must be greater than or equal to 1, min_gen_len={min_gen_len}" + assert min_gen_len > 0, ( + f"min_gen_len must be greater than or equal to 1, min_gen_len={min_gen_len}" + ) self.min_gen_len = min_gen_len @@ -163,9 +163,9 @@ def state_bag_reorder(self, new_order: torch.Tensor) -> None: def prefill(self, **kwargs) -> None: """encode the prefix with the context encoder""" - assert ( - self.context_state_bag is not None - ), "Expecting a context state bag to prefill" + assert self.context_state_bag is not None, ( + "Expecting a context state bag to prefill" + ) context: EmbeddingsBatch diff --git a/lcm/models/two_tower_diffusion_lcm/builder.py b/lcm/models/two_tower_diffusion_lcm/builder.py index 487d2cc..0572ad4 100644 --- a/lcm/models/two_tower_diffusion_lcm/builder.py +++ b/lcm/models/two_tower_diffusion_lcm/builder.py @@ -243,9 +243,9 @@ def prep_for_denoising(self, decoding_options): def sample_initial_noise_vectors(self, batch_size: int): # Check that we called `prep_for_denoising`: - assert hasattr( - self, "clip_noise" - ), "The model is not properly set for decoding, make sure to call `model.prep_for_denoising()`" + assert hasattr(self, "clip_noise"), ( + "The model is not properly set for decoding, make sure to call `model.prep_for_denoising()`" + ) # Sample a noise vector for next embedding prediction latents = torch.randn( @@ -269,9 +269,9 @@ def predict_next_sentence( # type: ignore context_state_bag: Optional[LCMIncrementalStateBag] = None, **kwargs, ) -> Tuple[EmbeddingsBatch, EmbeddingsBatch]: - assert ( - context_state_bag is not None - ), "Expected a state_bag to incrementally encode the context" + assert context_state_bag is not None, ( + "Expected a state_bag to incrementally encode the context" + ) if self.do_classifier_free_guidance: logger.debug("Running inference with CF-guidance...") @@ -347,9 +347,9 @@ def predict_next_sentence_with_cf_guidance( # type: ignore context_state_bag: Optional[LCMIncrementalStateBag] = None, **kwargs, ) -> Tuple[EmbeddingsBatch, EmbeddingsBatch]: - assert ( - context_state_bag is not None - ), "Expected a state_bag to incrementally encode the context" + assert context_state_bag is not None, ( + "Expected a state_bag to incrementally encode the context" + ) # Normalize the input embeddings if we're expected to # normalize outside of the model's forward pass @@ -533,9 +533,9 @@ def build_model(self) -> TwoTowerDiffusionLCModel: """Build a model.""" sonar_normalizer = self.build_sonar_normalizer() - assert ( - sonar_normalizer is not None - ), "TwoTowerDiffusionLCModel expects a `sonar_normalizer`" + assert sonar_normalizer is not None, ( + "TwoTowerDiffusionLCModel expects a `sonar_normalizer`" + ) # the context encoder encoder_frontend = self.build_frontend() diff --git a/lcm/nn/denoisers/factory.py b/lcm/nn/denoisers/factory.py index 331ac12..9651684 100644 --- a/lcm/nn/denoisers/factory.py +++ b/lcm/nn/denoisers/factory.py @@ -160,9 +160,9 @@ def build_model(self) -> LCMDenoiser: def build_layer(self) -> LCMDenoiserLayer: """Build a Transformer decoder layer based on the provided config.""" - assert isinstance( - self.config, DenoiserConfig - ), "Expecting a DenoiserConfig in the DenoiserTransformerFactory" + assert isinstance(self.config, DenoiserConfig), ( + "Expecting a DenoiserConfig in the DenoiserTransformerFactory" + ) self_attn = self.build_attention() diff --git a/lcm/nn/denoisers/lcm_denoiser.py b/lcm/nn/denoisers/lcm_denoiser.py index c6b613c..6129f0d 100644 --- a/lcm/nn/denoisers/lcm_denoiser.py +++ b/lcm/nn/denoisers/lcm_denoiser.py @@ -182,13 +182,13 @@ def forward( """ emb_timesteps = self.embed_time(diffusion_timesteps) - assert ( - conditioning_variables is not None - ), "Expected conditioning_variables, found None" + assert conditioning_variables is not None, ( + "Expected conditioning_variables, found None" + ) - assert ( - conditioning_variables is not None - ), "Mypy - Expecting non-None conditioning_variables" + assert conditioning_variables is not None, ( + "Mypy - Expecting non-None conditioning_variables" + ) conditioning_variables = torch.cat( [ @@ -438,9 +438,9 @@ def _forward_self_attn( ) -> Tensor: residual = seqs - assert ( - self.norm_order != TransformerNormOrder.POST - ), "DiT AdaLN expect pre-normalization" + assert self.norm_order != TransformerNormOrder.POST, ( + "DiT AdaLN expect pre-normalization" + ) if self.norm_order != TransformerNormOrder.POST: seqs = self.self_attn_layer_norm(seqs) @@ -489,9 +489,9 @@ def _forward_cross_attention( residual = seqs - assert ( - self.norm_order != TransformerNormOrder.POST - ), "DiT AdaLN expect pre-normalization" + assert self.norm_order != TransformerNormOrder.POST, ( + "DiT AdaLN expect pre-normalization" + ) if self.norm_order != TransformerNormOrder.POST: seqs = cast(LayerNorm, self.cross_attention_layer_norm)(seqs) @@ -521,9 +521,9 @@ def _forward_cross_attention( return seqs def _forward_ffn(self, seqs: Tensor, modulators: Tensor) -> Tensor: - assert ( - self.norm_order != TransformerNormOrder.POST - ), "DiT AdaLN expects pre-normalization" + assert self.norm_order != TransformerNormOrder.POST, ( + "DiT AdaLN expects pre-normalization" + ) residual = seqs if self.norm_order != TransformerNormOrder.POST: diff --git a/lcm/nn/transformer/factory.py b/lcm/nn/transformer/factory.py index c60f5a3..0bcfa08 100644 --- a/lcm/nn/transformer/factory.py +++ b/lcm/nn/transformer/factory.py @@ -175,9 +175,9 @@ def build_layer(self) -> TransformerDecoderLayer: ) # reset residual_scale if layer.residual_scale is not None: - assert ( - self.config.scale_residual is not None - ), f"Layer has a resiudal scale but scale={self.config.scale_residual}" + assert self.config.scale_residual is not None, ( + f"Layer has a resiudal scale but scale={self.config.scale_residual}" + ) torch.nn.init.constant_(layer.residual_scale, self.config.scale_residual) logger.info( f"Initializing the residual scale at {self.config.scale_residual}" diff --git a/lcm/train/common.py b/lcm/train/common.py index 625ffd1..cbcba35 100644 --- a/lcm/train/common.py +++ b/lcm/train/common.py @@ -33,9 +33,9 @@ def _parse_training_config(train_config: DictConfig): try: trainer_obj = hydra.utils.get_object(trainer_cls_or_func) sign = signature(trainer_obj) - assert ( - len(sign.parameters) == 1 and "config" in sign.parameters - ), f'{trainer_cls_or_func} should take a single argument called "config"' + assert len(sign.parameters) == 1 and "config" in sign.parameters, ( + f'{trainer_cls_or_func} should take a single argument called "config"' + ) param_type = sign.parameters["config"].annotation OmegaConf.resolve(train_config) diff --git a/lcm/train/criterion.py b/lcm/train/criterion.py index b22e149..b96a498 100644 --- a/lcm/train/criterion.py +++ b/lcm/train/criterion.py @@ -91,9 +91,9 @@ def register(cls, name: str) -> Callable: """decorator for adding criterions to the registry""" def inner_wrapper(wrapped_class: Criterion) -> Callable: - assert ( - name not in cls.registry - ), f"{name} is already register as a criterion" + assert name not in cls.registry, ( + f"{name} is already register as a criterion" + ) cls.registry[name] = wrapped_class return wrapped_class diff --git a/lcm/train/lcm/criterion.py b/lcm/train/lcm/criterion.py index 4c84f1a..90f3673 100644 --- a/lcm/train/lcm/criterion.py +++ b/lcm/train/lcm/criterion.py @@ -66,9 +66,9 @@ def compute_standard_mse( ) if normalizer is not None: - assert hasattr( - normalizer, "denormalize" - ), "The provided normalizer has not method `denormalize`" + assert hasattr(normalizer, "denormalize"), ( + "The provided normalizer has not method `denormalize`" + ) flattened_predictions = normalizer.denormalize(flattened_predictions) flattened_target = normalizer.denormalize(flattened_target) diff --git a/lcm/train/optim.py b/lcm/train/optim.py index 0a608a1..e0df481 100644 --- a/lcm/train/optim.py +++ b/lcm/train/optim.py @@ -29,16 +29,15 @@ def build_lr_scheduler( stage_ratio: Tuple[float, ...] = (0.1, 0.4, 0.5), schedule: str = "myle", ) -> AbstractLRScheduler: - assert ( - schedule - in [ - "noop", - "myle", - "cosine", - "wsd", - "polynomial", - ] - ), f"Cannot recognize the learing rate schedule {schedule}, only noop, myle, cosine and wsd are supported" + assert schedule in [ + "noop", + "myle", + "cosine", + "wsd", + "polynomial", + ], ( + f"Cannot recognize the learing rate schedule {schedule}, only noop, myle, cosine and wsd are supported" + ) assert lr > 0, "The learning reate should be strictly positive" @@ -66,14 +65,14 @@ def build_lr_scheduler( ) elif schedule == "wsd": - assert ( - lr > start_lr - ), f"the starting learning rate {start_lr} should be lesser than the main lr {lr}" + assert lr > start_lr, ( + f"the starting learning rate {start_lr} should be lesser than the main lr {lr}" + ) start_lr_scale = start_lr / lr - assert ( - lr > final_lr - ), f"the final learning rate {final_lr} should be lesser than the main lr {lr}" + assert lr > final_lr, ( + f"the final learning rate {final_lr} should be lesser than the main lr {lr}" + ) final_lr_scale = final_lr / lr lr_scheduler = TriStageLR( diff --git a/lcm/train/trainer.py b/lcm/train/trainer.py index 90667e8..22d7ecb 100644 --- a/lcm/train/trainer.py +++ b/lcm/train/trainer.py @@ -1060,9 +1060,9 @@ def _log_nan_loss(self, exc): class TrainerBuilder: def __init__(self, config: TrainingConfig): - assert ( - config.save_model_every_n_steps % config.checkpoint_every_n_steps == 0 - ), f"save_model_every_n_steps={config.save_model_every_n_steps} for saving consolidated models should be a multiplier of checkpoint_every_n_steps={config.checkpoint_every_n_steps}" + assert config.save_model_every_n_steps % config.checkpoint_every_n_steps == 0, ( + f"save_model_every_n_steps={config.save_model_every_n_steps} for saving consolidated models should be a multiplier of checkpoint_every_n_steps={config.checkpoint_every_n_steps}" + ) self.config = config @@ -1132,9 +1132,9 @@ def create_model_config(self, set_finetune_flag: bool = False): here inferred from the use of `model_config_or_name` """ if self.config.model_config_or_name is not None: - assert ( - self.config.model_arch is None - ), "We cannot set both `model_config_or_name` and `model_arch`" + assert self.config.model_arch is None, ( + "We cannot set both `model_config_or_name` and `model_arch`" + ) if isinstance(self.config.model_config_or_name, str): # The config of a registered model i.e. we're finetuning @@ -1180,7 +1180,9 @@ def create_model_config(self, set_finetune_flag: bool = False): elif self.config.model_arch is not None: assert ( self.config.model_arch in self.model_config_loader._arch_configs.names() - ), f"Could not recognise {self.config.model_arch} as a registered architecture " + ), ( + f"Could not recognise {self.config.model_arch} as a registered architecture " + ) logger.info( f"Creating a model from registered arch {self.config.model_arch}" @@ -1392,8 +1394,10 @@ def maybe_freeze_parameters(self, model): def _setup_additional_logging(self): if self.config.debug: - assert self.config.log_folder is not None, "Missing log_folder, \ + assert self.config.log_folder is not None, ( + "Missing log_folder, \ make sure the log_folder is properly set in the training config" + ) setup_additional_logging(log_folder=self.config.log_folder) @property diff --git a/lcm/train/two_tower_diffusion_lcm/criterion.py b/lcm/train/two_tower_diffusion_lcm/criterion.py index 2d10976..d037536 100644 --- a/lcm/train/two_tower_diffusion_lcm/criterion.py +++ b/lcm/train/two_tower_diffusion_lcm/criterion.py @@ -52,9 +52,9 @@ def __init__( style: LCMStyle = LCMStyle.UNSUPERVISED, ): super().__init__(config, model, style) - assert hasattr( - self.base_model, "noise_scheduler" - ), "Expecting the diffusion model to have a `noise_scheduler`" + assert hasattr(self.base_model, "noise_scheduler"), ( + "Expecting the diffusion model to have a `noise_scheduler`" + ) self.noise_scheduler = self.base_model.noise_scheduler self.prediction_type = self.noise_scheduler.prediction_type @@ -71,9 +71,9 @@ def __init__( f"trained_with_cf_guidance={self.trained_with_cf_guidance}", ) - assert ( - self.normalize_in_criterion - ), "We only support `normalize_in_criterion = True` in the diffusion criterions" + assert self.normalize_in_criterion, ( + "We only support `normalize_in_criterion = True` in the diffusion criterions" + ) self.summands.append("unnormalized_reconstruction_loss") @@ -378,9 +378,9 @@ def prepare_input_and_mask( # Prepare the input as in MSE LCM input_embeddings = batch.prepare_input(style=self.style) - assert ( - input_embeddings.source_lengths is not None - ), "Missing source lengths needed for the two-tower supervised fintuning" + assert input_embeddings.source_lengths is not None, ( + "Missing source lengths needed for the two-tower supervised fintuning" + ) target_embeddings = EmbeddingsBatch(*pad_seqs(batch.target)) # type: ignore diff --git a/lcm/utils/card_utils.py b/lcm/utils/card_utils.py index ac2a63c..8ab337e 100644 --- a/lcm/utils/card_utils.py +++ b/lcm/utils/card_utils.py @@ -147,9 +147,9 @@ def create_model_card_from_training_folder( model_config = training_config.model_config_or_name cp_fn = checkpoint_manager._checkpoint_dir / f"step_{step_nr}" / "model.pt" - assert ( - cp_fn - ), f"Checkpoint manager could not extract checkpoint path for step {step_nr}." + assert cp_fn, ( + f"Checkpoint manager could not extract checkpoint path for step {step_nr}." + ) # TODO: deal with the fine-tuning case, where model_config is a string if isinstance(model_config, str): parent_card = default_asset_store.retrieve_card(model_config) diff --git a/lcm/utils/logging.py b/lcm/utils/logging.py index 625bd46..66f550b 100644 --- a/lcm/utils/logging.py +++ b/lcm/utils/logging.py @@ -31,9 +31,9 @@ def log_git_status( repo: str = "lcm", tolerate_uncommitted: bool = False, ) -> str: - assert ( - repo in LCM_REPOS - ), f"Only the LCM core repos ({LCM_REPOS}) are supported in `log_git_status`" + assert repo in LCM_REPOS, ( + f"Only the LCM core repos ({LCM_REPOS}) are supported in `log_git_status`" + ) repo_path = os.path.dirname(globals()[repo].__file__) diff --git a/lcm/utils/model_type_registry.py b/lcm/utils/model_type_registry.py index 8d13e5e..974f7eb 100644 --- a/lcm/utils/model_type_registry.py +++ b/lcm/utils/model_type_registry.py @@ -39,9 +39,9 @@ def register(self, model_type_config: ModelTypeConfig) -> None: The factory to construct model configurations. """ model_type = model_type_config.model_type - assert ( - model_type - ), "To register a model type, the model_type parameter should be non-empty." + assert model_type, ( + "To register a model type, the model_type parameter should be non-empty." + ) if model_type in self._configs: raise ValueError( f"`model_type` must be a unique model type name, but '{model_type}' is already registered." diff --git a/tests/units/evaluation/test_cli.py b/tests/units/evaluation/test_cli.py index 664685e..ee9856f 100644 --- a/tests/units/evaluation/test_cli.py +++ b/tests/units/evaluation/test_cli.py @@ -66,9 +66,9 @@ def test_dynamic_prompt(tmp_path, simple_json_dataset, monkeypatch): ) ) for prompt, result in zip(default_prompts, results): - assert ( - result["text_prompts"] == prompt - ), f"Not match: {result['text_prompts']} != {prompt}" + assert result["text_prompts"] == prompt, ( + f"Not match: {result['text_prompts']} != {prompt}" + ) # Custom prompt with prefix and suffix with monkeypatch.context() as m2: m2.setattr( @@ -100,9 +100,9 @@ def test_dynamic_prompt(tmp_path, simple_json_dataset, monkeypatch): ) ) for prompt, result in zip(custom_prompts, results): - assert ( - result["text_prompts"] == prompt - ), f"Not match: {result['text_prompts']} != {prompt}" + assert result["text_prompts"] == prompt, ( + f"Not match: {result['text_prompts']} != {prompt}" + ) # Custom prompt with complex sequences of text with monkeypatch.context() as m3: @@ -141,6 +141,6 @@ def test_dynamic_prompt(tmp_path, simple_json_dataset, monkeypatch): ) ) for prompt, result in zip(custom_prompts, results): - assert ( - result["text_prompts"] == prompt - ), f"Not match: {result['text_prompts']} != {prompt}" + assert result["text_prompts"] == prompt, ( + f"Not match: {result['text_prompts']} != {prompt}" + ) diff --git a/tests/units/evaluation/test_generation_tasks.py b/tests/units/evaluation/test_generation_tasks.py index 82bdebf..884e4d0 100644 --- a/tests/units/evaluation/test_generation_tasks.py +++ b/tests/units/evaluation/test_generation_tasks.py @@ -209,6 +209,6 @@ def test_run_task_with_dynamic_predictor(simple_data_config): data_loading_config=eval_dlc, dataset_config=dsc, ) - assert ( - "m1" in avg_result.metrics and avg_result.metrics["m1"].avg == 0.0 - ), avg_result.metrics # type: ignore + assert "m1" in avg_result.metrics and avg_result.metrics["m1"].avg == 0.0, ( + avg_result.metrics + ) # type: ignore diff --git a/tests/units/test_recipes.py b/tests/units/test_recipes.py index d15c75a..1b99a6e 100644 --- a/tests/units/test_recipes.py +++ b/tests/units/test_recipes.py @@ -52,9 +52,9 @@ def test_train_recipes(monkeypatch, conf_name, tmp_path, group="train"): "++trainer.use_fsdp=false", ], ) - assert isinstance( - config, DictConfig - ), f"+{group}={conf_name} expect dict-type config, get {type(config)}." + assert isinstance(config, DictConfig), ( + f"+{group}={conf_name} expect dict-type config, get {type(config)}." + ) try: trainer = get_trainer(config.trainer) @@ -79,9 +79,9 @@ def test_train_recipes(monkeypatch, conf_name, tmp_path, group="train"): else: raise - assert isinstance( - trainer, Trainer - ), f"+{group}={conf_name} Error parsing recipe." + assert isinstance(trainer, Trainer), ( + f"+{group}={conf_name} Error parsing recipe." + ) def find_eval_recipes(): diff --git a/tests/units/training/test_toy_task_trainer.py b/tests/units/training/test_toy_task_trainer.py index cfa7e24..6955201 100644 --- a/tests/units/training/test_toy_task_trainer.py +++ b/tests/units/training/test_toy_task_trainer.py @@ -100,9 +100,9 @@ def test_toy_mse_training(tmp_path, simple_train_dataset, simple_validation_data ), f"{param_name} differs in checkpoint!" # Testing that the model card has been created - assert ( - train_dirname / "model_card.yaml" - ).exists(), f"The file {train_dirname}/model_card.yaml does not exist" + assert (train_dirname / "model_card.yaml").exists(), ( + f"The file {train_dirname}/model_card.yaml does not exist" + ) # Testing that the model card can be used to load the model correctly card = trainer.create_model_card_for_last_checkpoint()