diff --git a/tests/test_io.py b/tests/test_io.py index fdb7e3157f6..3f1ad3416c0 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -2,12 +2,15 @@ from pathlib import Path from unittest import TestCase +from datetime import datetime import numpy as np import pandas as pd import pytest from topostats.io import ( read_yaml, + get_date_time, + write_config_with_comments, write_yaml, save_array, load_array, @@ -42,6 +45,12 @@ # pylint: disable=protected-access +def test_get_date_time() -> None: + """Test the fetching of a formatted date and time string.""" + + assert datetime.strptime(get_date_time(), "%Y-%m-%d %H:%M:%S") + + def test_read_yaml() -> None: """Test reading of YAML file.""" sample_config = read_yaml(RESOURCES / "test.yaml") @@ -49,6 +58,29 @@ def test_read_yaml() -> None: TestCase().assertDictEqual(sample_config, CONFIG) +def test_write_config_with_comments(tmp_path: Path) -> None: + """ + Test that the function write_yaml_with_comments successfully writes the default + config with comments to a file. + """ + + # Read default config with comments + with open(BASE_DIR / "topostats" / "default_config.yaml", encoding="utf-8") as f: + default_config_string = f.read() + + # Write default config with comments to file + write_config_with_comments(config=default_config_string, output_dir=tmp_path, filename="test_config_with_comments") + + # Read the written config + with open(tmp_path / "test_config_with_comments.yaml", encoding="utf-8") as f: + written_config = f.read() + + # Validate that the written config has comments in it + assert default_config_string in written_config + assert "Config file generated" in written_config + assert "For more information on configuration and how to use it" in written_config + + def test_write_yaml(tmp_path: Path) -> None: """Test writing of dictionary to YAML.""" write_yaml( diff --git a/topostats/io.py b/topostats/io.py index 51af6c6cb45..ece8992394c 100644 --- a/topostats/io.py +++ b/topostats/io.py @@ -20,6 +20,9 @@ LOGGER = logging.getLogger(LOGGER_NAME) +CONFIG_DOCUMENTATION_REFERENCE = """For more information on configuration and how to use it: +# https://afm-spm.github.io/TopoStats/main/configuration.html\n""" + # pylint: disable=broad-except @@ -46,6 +49,23 @@ def read_yaml(filename: Union[str, Path]) -> Dict: return {} +def get_date_time() -> str: + """ + Get a date and time for adding to generated files or logging. + + Parameters + ---------- + None + + Returns + ------- + str + A string of the current date and time, formatted appropriately. + """ + + return datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + def write_yaml( config: dict, output_dir: Union[str, Path], config_file: str = "config.yaml", header_message: str = None ) -> None: @@ -67,17 +87,12 @@ def write_yaml( # Revert PosixPath items to string config = path_to_str(config) config_yaml = yaml_load(yaml_dump(config)) - documentation_reference = ( - "For more information on configuration : https://afm-spm.github.io/TopoStats/main/configuration.html" - ) + if header_message: - config_yaml.yaml_set_start_comment( - f"{header_message} : {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n" + documentation_reference - ) + config_yaml.yaml_set_start_comment(f"{header_message} : {get_date_time()}\n" + CONFIG_DOCUMENTATION_REFERENCE) else: config_yaml.yaml_set_start_comment( - f"Configuration from TopoStats run completed : {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n" - + documentation_reference + f"Configuration from TopoStats run completed : {get_date_time()}\n" + CONFIG_DOCUMENTATION_REFERENCE ) with output_config.open("w") as f: try: @@ -86,6 +101,34 @@ def write_yaml( LOGGER.error(exception) +def write_config_with_comments(config: str, output_dir: Path, filename: str = "config.yaml") -> None: + """ + Create a config file, retaining the comments by writing it as a string + rather than using a yaml handling package. + + Parameters + ---------- + config: str + A string of the entire configuration file to be saved. + output_dir: Path + A pathlib path of where to create the config file. + filename: str + A name for the configuration file. Can have a ".yaml" on the end. + """ + + if ".yaml" not in filename and ".yml" not in filename: + create_config_path = output_dir / f"{filename}.yaml" + else: + create_config_path = output_dir / filename + + with open(f"{create_config_path}", "w", encoding="utf-8") as f: + f.write(f"# Config file generated {get_date_time()}\n") + f.write(f"# {CONFIG_DOCUMENTATION_REFERENCE}") + f.write(config) + LOGGER.info(f"A sample configuration has been written to : {str(create_config_path)}") + LOGGER.info(CONFIG_DOCUMENTATION_REFERENCE) + + def save_array(array: np.ndarray, outpath: Path, filename: str, array_type: str) -> None: """Save a Numpy array to disk. diff --git a/topostats/run_topostats.py b/topostats/run_topostats.py index 16d46d88a09..f9d4677c532 100644 --- a/topostats/run_topostats.py +++ b/topostats/run_topostats.py @@ -10,13 +10,21 @@ from multiprocessing import Pool from pprint import pformat import sys +from pathlib import Path import yaml import pandas as pd from tqdm import tqdm from topostats import __version__ -from topostats.io import find_files, read_yaml, save_folder_grainstats, write_yaml, LoadScans +from topostats.io import ( + find_files, + read_yaml, + save_folder_grainstats, + write_yaml, + write_config_with_comments, + LoadScans, +) from topostats.logs.logs import LOGGER_NAME from topostats.plotting import toposum from topostats.processing import check_run_steps, completion_message, process_scan @@ -142,11 +150,12 @@ def main(args=None): # Parse command line options, load config (or default) and update with command line options parser = create_parser() args = parser.parse_args() if args is None else parser.parse_args(args) - if args.config_file is not None: - config = read_yaml(args.config_file) + if args.config_file is None: + default_config = pkg_resources.open_text(__package__, "default_config.yaml").read() + config = yaml.safe_load(default_config) else: - default_config = pkg_resources.open_text(__package__, "default_config.yaml") - config = yaml.safe_load(default_config.read()) + config = read_yaml(args.config_file) + config = update_config(config, args) # Set logging level @@ -162,19 +171,10 @@ def main(args=None): validate_config(config, schema=DEFAULT_CONFIG_SCHEMA, config_type="YAML configuration file") # Write sample configuration if asked to do so and exit + if args.create_config_file and args.config_file: + raise ValueError("--create-config-file and --config cannot be used together.") if args.create_config_file: - write_yaml( - config, - output_dir="./", - config_file=args.create_config_file, - header_message="Sample configuration file auto-generated", - ) - LOGGER.info(f"A sample configuration has been written to : ./{args.create_config_file}") - LOGGER.info( - "Please refer to the documentation on how to use the configuration file : \n\n" - "https://afm-spm.github.io/TopoStats/usage.html#configuring-topostats\n" - "https://afm-spm.github.io/TopoStats/configuration.html" - ) + write_config_with_comments(config=default_config, output_dir=Path.cwd(), filename=args.create_config_file) sys.exit() # Create base output directory