diff --git a/mltb2/files.py b/mltb2/files.py
index 4eec687..03da345 100644
--- a/mltb2/files.py
+++ b/mltb2/files.py
@@ -150,10 +150,22 @@ def _write_lock_files(self, batch: Sequence[Dict[str, Any]]) -> None:
             (self._result_dir_path / f"{uuid}.lock").touch()
             self._own_lock_uuids.add(uuid)
 
-    def read_batch(self) -> Sequence[Dict[str, Any]]:
-        """Read the next batch of data."""
+    def _get_remaining_data(self) -> List[Dict[str, Any]]:
         locked_or_done_uuids: Set[str] = self._get_locked_or_done_uuids()
         remaining_data = [d for d in self.data if d[self.uuid_name] not in locked_or_done_uuids]
+        return remaining_data
+
+    def read_batch(self) -> Sequence[Dict[str, Any]]:
+        """Read the next batch of data."""
+        remaining_data: List[Dict[str, Any]] = self._get_remaining_data()
+
+        # if we think we are done, delete all lock files and check again
+        # this is because lock files might be orphaned
+        if len(remaining_data) == 0:
+            for lock_file_path in self._result_dir_path.glob("*.lock"):
+                lock_file_path.unlink(missing_ok=True)
+            remaining_data = self._get_remaining_data()
+
         random.shuffle(remaining_data)
         next_batch_size = min(self.batch_size, len(remaining_data))
         next_batch = remaining_data[:next_batch_size]
diff --git a/pyproject.toml b/pyproject.toml
index efa9c0f..bd6d9f4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "mltb2"
-version = "1.0.0rc1"
+version = "1.0.0rc2"
 description = "Machine Learning Toolbox 2"
 authors = ["PhilipMay "]
 readme = "README.md"
diff --git a/tests/test_files.py b/tests/test_files.py
index cdeabc8..14297aa 100644
--- a/tests/test_files.py
+++ b/tests/test_files.py
@@ -225,3 +225,32 @@ def test_FileBasedRestartableBatchDataProcessor_len(tmp_path):
         data=data, batch_size=10, uuid_name="uuid", result_dir=result_dir
     )
     assert len(data_processor) == 77
+
+
+def test_FileBasedRestartableBatchDataProcessor_clear_lock_files(tmp_path):
+    result_dir = tmp_path.absolute()
+    batch_size = 10
+    data = [{"uuid": str(uuid4()), "x": i} for i in range(100)]
+    data_processor = FileBasedRestartableBatchDataProcessor(
+        data=data, batch_size=batch_size, uuid_name="uuid", result_dir=result_dir
+    )
+
+    _ = data_processor.read_batch()  # create empty lock files
+
+    # process all data
+    while True:
+        _data = data_processor.read_batch()
+        if len(_data) == 0:
+            break
+        data_processor.save_batch(_data)
+
+    del data_processor
+    processed_data = FileBasedRestartableBatchDataProcessor.load_data(result_dir)
+
+    assert len(processed_data) == len(data)
+    for d in processed_data:
+        assert "uuid" in d
+        assert "x" in d
+        assert isinstance(d["uuid"], str)
+        assert isinstance(d["x"], int)
+        assert d["x"] < 100
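
For context, here is a minimal usage sketch of the restart flow this diff hardens. It only uses the API visible in the diff and tests above (`read_batch`, `save_batch`, `load_data`, and the constructor arguments); the result directory path and the record shape are illustrative, not part of the change.

```python
from pathlib import Path
from uuid import uuid4

from mltb2.files import FileBasedRestartableBatchDataProcessor

result_dir = Path("/tmp/mltb2_results")  # illustrative location
# each record needs a unique id under the field named by `uuid_name`
data = [{"uuid": str(uuid4()), "x": i} for i in range(100)]
processor = FileBasedRestartableBatchDataProcessor(
    data=data, batch_size=10, uuid_name="uuid", result_dir=result_dir
)

# simulate a crashed worker: read_batch() writes <uuid>.lock files,
# but the process dies before save_batch() stores any results
_orphaned = processor.read_batch()

# on restart, once every remaining record is locked or done, the patched
# read_batch() deletes all *.lock files and re-checks, so orphaned
# records are handed out again instead of being silently skipped
while len(batch := processor.read_batch()) > 0:
    processor.save_batch(batch)

processed = FileBasedRestartableBatchDataProcessor.load_data(result_dir)
assert len(processed) == len(data)
```

Before this change, the orphaned `.lock` files from the first `read_batch()` would have kept those ten records out of every subsequent batch, so `load_data` would have returned fewer results than inputs.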