Merge pull request #20 from facebookresearch/eval/fix_ci

Fix CI scripts

antoine-tran authored Jan 15, 2025
2 parents b185028 + 16f2244 commit 67b7ccc
Showing 39 changed files with 269 additions and 267 deletions.
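Apart from the CI script fix in `.github/actions/setup/action.yaml`, every hunk shown below applies the same mechanical rewrap: multi-line `assert` statements are changed so the condition stays on the `assert` line and only the failure message is wrapped in parentheses. This is consistent with the parenthesized-message style that newer Python autoformatters emit, though the diff itself does not name the tool. A minimal, hypothetical before/after sketch, not taken from the repository:

```python
# Illustrative only: `config` is a stand-in variable, not an object from the lcm codebase.
config = {"seed": 42}

# Old style: the condition is parenthesized and split across lines.
assert (
    config is not None
), f"Cannot build data pipeline from config {config}"

# New style: the condition stays inline; only the message is wrapped.
assert config is not None, (
    f"Cannot build data pipeline from config {config}"
)
```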
2 changes: 1 addition & 1 deletion .github/actions/setup/action.yaml
@@ -8,4 +8,4 @@ runs:
- name: "Install libsndfile"
shell: bash
run: |
sudo apt-get install libsndfile1 python3-distutils
sudo apt-get install libsndfile1
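The CI fix itself is this first hunk: `python3-distutils` is dropped from the apt install. A plausible explanation, not stated in the commit, is that `distutils` was removed from the Python 3.12 standard library, so the corresponding apt package is unavailable for the default Python on newer Ubuntu runner images and requesting it breaks the step. A small sketch to probe that assumption on any interpreter:

```python
# Assumption (not from the commit): the failing CI step relates to distutils
# leaving the standard library in Python 3.12.
import sys

try:
    import distutils  # noqa: F401
    print(f"distutils importable on Python {sys.version_info.major}.{sys.version_info.minor}")
except ModuleNotFoundError:
    print("distutils not importable; the stdlib module was removed in Python 3.12")
```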
6 changes: 3 additions & 3 deletions lcm/datasets/base.py
@@ -53,9 +53,9 @@ def pipeline(self) -> DataPipeline:
self._pipeline = self.builder_func(
self.datasets, self.data_config, gang_rank, world_size
)
assert (
self._pipeline
), f"Cannot build data pipeline from config {self.data_config}"
assert self._pipeline, (
f"Cannot build data pipeline from config {self.data_config}"
)
return self._pipeline

def destroy(self) -> None:
12 changes: 6 additions & 6 deletions lcm/datasets/batch.py
@@ -249,9 +249,9 @@ def __post_init__(self):

length = len(self.source)

assert (
(self.target is None) or (len(self.target) == length)
), f"all elements in LCMInput should be of the same length, got {len(self.target)} and {length}"
assert (self.target is None) or (len(self.target) == length), (
f"all elements in LCMInput should be of the same length, got {len(self.target)} and {length}"
)

def __len__(self) -> int:
return len(self.source)
@@ -296,9 +296,9 @@ def prepare_input(
)

elif style == LCMStyle.SUPERVISED:
assert (
self.target is not None
), "Missing target embeddings for a supervised batch"
assert self.target is not None, (
"Missing target embeddings for a supervised batch"
)
return get_embeddings_sequence(
src_seqs=self.source,
tgt_seqs=self.target,
6 changes: 3 additions & 3 deletions lcm/datasets/dataloader.py
@@ -195,9 +195,9 @@ def _tokenize_batch(self, batch: Dict[str, Any]) -> LCMInput:
else:
embs = None
outputs[key] = embs
assert (
outputs["source"] is not None
), "LCMDataLoader requires `source` sequences to be present in batches"
assert outputs["source"] is not None, (
"LCMDataLoader requires `source` sequences to be present in batches"
)
return LCMInput(**outputs)

def iterate_batches(self) -> Iterator[LCMInput]:
48 changes: 24 additions & 24 deletions lcm/datasets/dataloading.py
@@ -202,9 +202,9 @@ def build_dataload_pipeline(
self, rank: int = 0, world_size: int = 1
) -> DataPipelineBuilder:
if world_size > 1:
assert (
self.loading_config.seed is not None
), "for distributed training with `world_size` > 1, `seed` should be set !"
assert self.loading_config.seed is not None, (
"for distributed training with `world_size` > 1, `seed` should be set !"
)
if self.is_validation:
self.set_validation_params(world_size)

@@ -321,12 +321,12 @@ def create_on_the_fly_columns(
self, pipeline: DataPipelineBuilder
) -> DataPipelineBuilder:
if self.dataset_config.source_sequences is not None:
assert (
self.dataset_config.source_column is not None
), f"Expected a source_column - found {self.dataset_config.source_column}"
assert (
self.dataset_config.source_text_column is not None
), f"Expected a source_text_column - found {self.dataset_config.source_text_column}"
assert self.dataset_config.source_column is not None, (
f"Expected a source_column - found {self.dataset_config.source_column}"
)
assert self.dataset_config.source_text_column is not None, (
f"Expected a source_text_column - found {self.dataset_config.source_text_column}"
)

pipeline = pipeline.map(
partial(
@@ -338,12 +338,12 @@ def create_on_the_fly_columns(
num_parallel_calls=self._num_parallel_call(self.nb_parallel_fragments),
)
if self.dataset_config.target_sequences is not None:
assert (
self.dataset_config.target_column is not None
), f"Expected a target_column, found {self.dataset_config.target_column}"
assert (
self.dataset_config.target_text_column is not None
), f"Expected a target_text_columns, found {self.dataset_config.target_text_column}"
assert self.dataset_config.target_column is not None, (
f"Expected a target_column, found {self.dataset_config.target_column}"
)
assert self.dataset_config.target_text_column is not None, (
f"Expected a target_text_columns, found {self.dataset_config.target_text_column}"
)

pipeline = pipeline.map(
partial(
@@ -426,9 +426,9 @@ def config_post_init(self) -> None:
)

if self.loading_config.even_sharding:
assert (
self.loading_config.seed is not None
), "`even_sharding` sharding requires to seed to be set"
assert self.loading_config.seed is not None, (
"`even_sharding` sharding requires to seed to be set"
)

if self.loading_config.max_tokens == 0:
self.loading_config.max_tokens = None
@@ -876,9 +876,9 @@ def add_min_max_sentence_len_in_doc_filter(
self.loading_config.max_sentence_len_in_doc
or self.loading_config.min_sentence_len_in_doc
):
assert (
self.dataset_config.source_text_column is not None
), f"Expexted a source_text_columns, found {self.dataset_config.source_text_column}"
assert self.dataset_config.source_text_column is not None, (
f"Expexted a source_text_columns, found {self.dataset_config.source_text_column}"
)

pipeline = pipeline.map(
partial(
@@ -962,9 +962,9 @@ def add_quality_score_filters(
if source_quality_range is None:
return pipeline

assert (
self.dataset_config.source_quality_column is not None
), f"Expected a source_quality_columns, found {self.dataset_config.source_quality_column}"
assert self.dataset_config.source_quality_column is not None, (
f"Expected a source_quality_columns, found {self.dataset_config.source_quality_column}"
)

pipeline = pipeline.map(
partial(
12 changes: 6 additions & 6 deletions lcm/datasets/parquet_utils.py
@@ -486,9 +486,9 @@ def build_batching_loop_over_one_table(
num_parallel_calls: int = 1,
) -> DataPipeline:
if max_tokens is not None:
assert (
length_column is not None
), "Need to provide a column to compute the number of tokens"
assert length_column is not None, (
"Need to provide a column to compute the number of tokens"
)

random_state = np.random.RandomState(seed)
if length_column is not None and len(length_column) > 0:
@@ -1109,9 +1109,9 @@ def get_row_group_level_metadata(
columns_to_exclude = set(["row_group_id", "num_rows", "total_byte_size"]) & set(
columns
)
assert (
len(columns_to_exclude) == 0
), f"names conflict, rename/remove : {columns_to_exclude}"
assert len(columns_to_exclude) == 0, (
f"names conflict, rename/remove : {columns_to_exclude}"
)

def get_one_row_group_stats(row_group):
metadata = row_group.metadata
18 changes: 9 additions & 9 deletions lcm/evaluation/arun.py
@@ -114,12 +114,12 @@ def run(self, iteration_value: Optional[Any] = None, iteration_index: int = 0):
)

if iteration_value is not None:
assert (
isinstance(iteration_value, int) and self.config.nshards
), f"Invalid shard value ({self.config.nshards}) or iteration value ({iteration_value})"
assert (
self.config.data_loading
), f"Data loading is not specified: \n {self.config}"
assert isinstance(iteration_value, int) and self.config.nshards, (
f"Invalid shard value ({self.config.nshards}) or iteration value ({iteration_value})"
)
assert self.config.data_loading, (
f"Data loading is not specified: \n {self.config}"
)
self.config.data_loading.rank = iteration_value
self.config.data_loading.world_size = int(self.config.nshards)

@@ -194,9 +194,9 @@ async def schedule_task(
result = (metrics, result_file)

result_metrics, result_file = result
assert isinstance(
result_metrics, dict
), f"Expected Tuple[Dict[str, AverageMetrics], str], get {type(result_metrics)}"
assert isinstance(result_metrics, dict), (
f"Expected Tuple[Dict[str, AverageMetrics], str], get {type(result_metrics)}"
)

metrics = {}
cf = getattr(module.config, "confidence_level", None)
6 changes: 3 additions & 3 deletions lcm/evaluation/cli/configs.py
@@ -78,9 +78,9 @@ class CliConfig:

def __post_init__(self) -> None:
self.metric_log_dir = self.metric_log_dir or self.dump_dir
assert (
self.temperature >= 0.0
), f"Expect non-zero temperature, get {self.temperature}"
assert self.temperature >= 0.0, (
f"Expect non-zero temperature, get {self.temperature}"
)
if self.temperature == 0:
self.top_p = 0
self.top_k = 0
6 changes: 3 additions & 3 deletions lcm/evaluation/metrics/__init__.py
@@ -57,9 +57,9 @@ def get_scorer(
if "outputs" in defaults:
output_columns = defaults["outputs"].default
else:
assert (
config.model_name
), f"Cannot resolve output name for the scorer type {scorer_cls}"
assert config.model_name, (
f"Cannot resolve output name for the scorer type {scorer_cls}"
)
output_columns = scorer_cls.default_outputs(config.model_name)

if isinstance(output_columns, str):
6 changes: 3 additions & 3 deletions lcm/evaluation/metrics/multilingual_similarity.py
@@ -68,9 +68,9 @@ def translate(
) -> List[str]:
src_lang, tgt_lang = self.src_lang, self.tgt_lang
sent_translations = []
assert isinstance(
self.model, EncoderDecoderModel
), f"Unsupported type: {type(self.model)}"
assert isinstance(self.model, EncoderDecoderModel), (
f"Unsupported type: {type(self.model)}"
)
generator = BeamSearchSeq2SeqGenerator(
self.model, echo_prompt=True, max_seq_len=self.max_seq_len
)
12 changes: 6 additions & 6 deletions lcm/evaluation/metrics/sentence_fluency.py
@@ -200,9 +200,9 @@ def score_texts(
bos_token = bos_token or getattr(self.tokenizer, "bos_token", "\n")
if eos_token != "":
eos_token = eos_token or getattr(self.tokenizer, "eos_token", "\n")
assert (
eos_token is not None and bos_token is not None
), "Expecting eos and bos tokens, for perplexity without any surrounding tokens, use eos_token='' and bos_token=''"
assert eos_token is not None and bos_token is not None, (
"Expecting eos and bos tokens, for perplexity without any surrounding tokens, use eos_token='' and bos_token=''"
)
logger.info(
f"Computing perplexity with bos_token={repr(bos_token)} and eos_token={repr(eos_token)}"
)
@@ -340,9 +340,9 @@ def backtranslate(
translations = []
back_translations = []
losses = []
assert isinstance(
self.model, EncoderDecoderModel
), f"Unsupported type: {type(self.model)}"
assert isinstance(self.model, EncoderDecoderModel), (
f"Unsupported type: {type(self.model)}"
)
generator = BeamSearchSeq2SeqGenerator(
self.model, echo_prompt=True, max_seq_len=self.max_seq_len
)
6 changes: 3 additions & 3 deletions lcm/evaluation/predictors/__init__.py
@@ -42,9 +42,9 @@ def build_predictor(
if isinstance(predictor_config, PredictorConfig):
config_cls = predictor_config.__class__
else:
assert (
predictor_type is not None
), f"Cannot infer predictor from config type {type(predictor_config)}"
assert predictor_type is not None, (
f"Cannot infer predictor from config type {type(predictor_config)}"
)
config_cls = get_config_cls(predictor_type)
predictor_config = promote_config(predictor_config, config_cls)

12 changes: 6 additions & 6 deletions lcm/evaluation/predictors/lcm.py
@@ -209,9 +209,9 @@ def __call__(
if max_gen_len is None:
max_gen_len = self.config.max_seq_len

assert isinstance(
prompts, Batched
), f"Expect sequence of prompts, get {type(prompts)}"
assert isinstance(prompts, Batched), (
f"Expect sequence of prompts, get {type(prompts)}"
)

# Extract the input embeddings
seqs: List[torch.Tensor] = []
@@ -226,9 +226,9 @@
)
seqs.append(prompt_embs)
else:
assert (
isinstance(prompts, list) and isinstance(prompts[0], torch.Tensor)
), f"Expect sonarized prompts in the form or List[torch.Tensor], get {type(prompts)}"
assert isinstance(prompts, list) and isinstance(prompts[0], torch.Tensor), (
f"Expect sonarized prompts in the form or List[torch.Tensor], get {type(prompts)}"
)
seqs = prompts

if max_prompt_len:
6 changes: 3 additions & 3 deletions lcm/evaluation/run.py
@@ -190,9 +190,9 @@ def run_task(
task_config = TaskRegistry.get_config(run_config.task_name, **run_config.params)

data_loader_type = TaskRegistry.get_dataloader_type(run_config.task_name)
assert isinstance(
run_config.data_loading, EvaluationDataLoadingConfig
), "data loading not specified"
assert isinstance(run_config.data_loading, EvaluationDataLoadingConfig), (
"data loading not specified"
)
task = build_task(
task_config,
data_loading_config=run_config.data_loading,
24 changes: 12 additions & 12 deletions lcm/evaluation/tasks/base.py
@@ -103,13 +103,13 @@ def load_few_shots(
few_shot_examples = config.few_shot_examples
if config.num_few_shot > 0:
if few_shot_examples is None:
assert (
config.few_shot_file
), f"Expect non-empty few_shot_file when few_shot = {config.num_few_shot}"
assert config.few_shot_file, (
f"Expect non-empty few_shot_file when few_shot = {config.num_few_shot}"
)
assert data_loader, "Expect non-empty data loader"
assert isinstance(
data_loader.data_config, EvaluationDataLoadingConfig
), f"unexpected data loading type: {type(data_loader.data_config)}"
assert isinstance(data_loader.data_config, EvaluationDataLoadingConfig), (
f"unexpected data loading type: {type(data_loader.data_config)}"
)
few_shot_data_loader = JSONTestDataLoader(
data_config=data_loader.data_config,
dataset=JSONDatasetConfig(file_path=config.few_shot_file),
@@ -487,9 +487,9 @@ def run( # type: ignore[override]
x["nlls"] = defaultdict(list)
for cix, (text, pred) in enumerate(zip(x["choice_texts"], preds)):
assert pred.tokens and pred.logprobs and pred.text_offsets, pred
assert isinstance(
pred.text, str
), "multiple texts output is not supported in LLM predictor"
assert isinstance(pred.text, str), (
"multiple texts output is not supported in LLM predictor"
)
logprobs = pred.logprobs[
text_index(pred.text, pred.text_offsets, text)
]
@@ -501,9 +501,9 @@
assert len(preds) == 2 * len(x["choice_texts"])
compl = preds[cix + len(x["choice_texts"])]
assert compl.tokens and compl.logprobs and compl.text_offsets
assert isinstance(
compl.text, str
), "multiple texts output is not supported in LLM predictor"
assert isinstance(compl.text, str), (
"multiple texts output is not supported in LLM predictor"
)
slice = text_index(compl.text, compl.text_offsets, text)
nll_compl = -sum(logprobs) + sum(compl.logprobs[slice])
x["nlls"]["completion"].append(nll_compl)
9 changes: 4 additions & 5 deletions lcm/evaluation/utils/common.py
@@ -213,8 +213,7 @@ def string_format(template: str, skip_validation: bool = True, **kwargs: Any) ->
variables = [k[1] for k in string.Formatter().parse(template) if k[1]]
if not all(k in kwargs for k in variables):
raise ValueError(
f"Expected: {variables}, got: {sorted(kwargs)}.\n"
f"Template:\n{template}"
f"Expected: {variables}, got: {sorted(kwargs)}.\nTemplate:\n{template}"
)
# `Dict[Optional[str], typing.Any]`.
kwargs = {k: kwargs[k] for k in variables}
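The `string_format` hunk above only collapses two adjacent f-string literals into one; Python concatenates adjacent literals at parse time, so the raised error message is unchanged. A standalone illustration of that equivalence, using made-up values rather than the function's real inputs:

```python
# Adjacent (f-)string literals concatenate at compile time, so splitting or
# merging them produces identical messages.
variables = ["text", "label"]
received = {"text": "hello"}
template = "{text} -> {label}"

msg_split = (
    f"Expected: {variables}, got: {sorted(received)}.\n"
    f"Template:\n{template}"
)
msg_merged = f"Expected: {variables}, got: {sorted(received)}.\nTemplate:\n{template}"
assert msg_split == msg_merged
```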
@@ -580,9 +579,9 @@ def flatten_dict(data: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
if isinstance(value, dict):
flat_dict.update(flatten_dict(value, f"{prefix}{key}/"))
elif isinstance(value, Tensor):
assert (
len(value.size()) == 0
), f"Only scalar tensor can be formatted, get {value}"
assert len(value.size()) == 0, (
f"Only scalar tensor can be formatted, get {value}"
)
flat_dict[f"{prefix}{key}"] = value.item()
elif isinstance(value, AverageMetric):
flat_dict[f"{prefix}{key}"] = value.value