diff --git a/argdantic/stores/base.py b/argdantic/stores/base.py index 94fa09c..8141485 100644 --- a/argdantic/stores/base.py +++ b/argdantic/stores/base.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from pathlib import Path -from typing import Callable, Optional, Set, Union +from typing import Callable, Literal, Optional, Set, Union from pydantic import BaseModel from pydantic_settings import BaseSettings @@ -18,6 +18,7 @@ def __init__( self, path: Union[str, Path], *, + mode: Literal["python", "json"] = "python", encoding: str = "utf-8", include: Optional[Set[str]] = None, exclude: Optional[Set[str]] = None, @@ -27,6 +28,7 @@ def __init__( exclude_none: bool = False, ) -> None: self.path = Path(path) + self.mode = mode self.encoding = encoding self.include = include self.exclude = exclude diff --git a/argdantic/stores/json.py b/argdantic/stores/json.py index 2bfc1a6..c23b156 100644 --- a/argdantic/stores/json.py +++ b/argdantic/stores/json.py @@ -1,3 +1,6 @@ +from pathlib import Path +from typing import Optional, Set, Union + from pydantic_settings import BaseSettings from argdantic.stores.base import BaseSettingsStore @@ -9,6 +12,30 @@ class JsonSettingsStore(BaseSettingsStore): Orjson is used if available, otherwise the standard json module is used. """ + def __init__( + self, + path: Union[str, Path], + *, + encoding: str = "utf-8", + include: Optional[Set[str]] = None, + exclude: Optional[Set[str]] = None, + by_alias: bool = False, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + ) -> None: + super().__init__( + path, + mode="json", + encoding=encoding, + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + def __call__(self, settings: BaseSettings) -> None: with self.path.open("wb") as f: text = settings.model_dump_json( diff --git a/argdantic/stores/toml.py b/argdantic/stores/toml.py index 4e023fc..57bf69a 100644 --- a/argdantic/stores/toml.py +++ b/argdantic/stores/toml.py @@ -21,6 +21,7 @@ def __call__(self, settings: BaseSettings) -> None: with self.path.open("wb") as f: text = toml.dumps( settings.model_dump( + mode=self.mode, include=self.include, exclude=self.exclude, by_alias=self.by_alias, diff --git a/argdantic/stores/yaml.py b/argdantic/stores/yaml.py index 7fd891a..71b616e 100644 --- a/argdantic/stores/yaml.py +++ b/argdantic/stores/yaml.py @@ -21,6 +21,7 @@ def __call__(self, settings: BaseSettings) -> None: with self.path.open("w") as f: data = settings.model_dump( + mode=self.mode, include=self.include, exclude=self.exclude, by_alias=self.by_alias, diff --git a/tests/test_stores/test_json.py b/tests/test_stores/test_json.py index 5ae1ca8..13dcd69 100644 --- a/tests/test_stores/test_json.py +++ b/tests/test_stores/test_json.py @@ -79,3 +79,20 @@ def main(foo: str = "baz", bar: int = 42) -> None: result = runner.invoke(cli, []) assert result.exception is None assert result.return_value == ("baz", 42) + + +def test_parser_using_json_store_complex_data(tmp_path: Path, runner: CLIRunner) -> None: + from pathlib import Path + + from argdantic import ArgParser + + cli = ArgParser() + path = tmp_path / "settings.json" + + @cli.command(stores=[JsonSettingsStore(path)]) + def main(foo: Path = "baz", bar: int = 42) -> None: + return foo, bar + + result = runner.invoke(cli, []) + assert result.exception is None + assert result.return_value == ("baz", 42) diff --git a/tests/test_stores/test_toml.py b/tests/test_stores/test_toml.py index 8925901..22ed593 100644 --- a/tests/test_stores/test_toml.py +++ b/tests/test_stores/test_toml.py @@ -61,3 +61,19 @@ def main(foo: str = "baz", bar: int = 42) -> None: result = runner.invoke(parser, []) assert result.exception is None assert result.return_value == ("baz", 42) + + +def test_parser_using_toml_store_complex_data(tmp_path: Path, runner: CLIRunner) -> None: + from argdantic import ArgParser + from argdantic.stores.toml import TomlSettingsStore + + path = tmp_path / "settings.toml" + parser = ArgParser() + + @parser.command(stores=[TomlSettingsStore(path, mode="json")]) + def main(foo: Path = "baz", bar: int = 42) -> None: + return foo, bar + + result = runner.invoke(parser, []) + assert result.exception is None + assert result.return_value == ("baz", 42) diff --git a/tests/test_stores/test_yaml.py b/tests/test_stores/test_yaml.py index 01a7975..84d313f 100644 --- a/tests/test_stores/test_yaml.py +++ b/tests/test_stores/test_yaml.py @@ -68,3 +68,24 @@ def main(foo: str = "baz", bar: int = 42) -> None: assert result.exception is None assert result.return_value == ("qux", 42) assert result.return_value == ("qux", 42) + + +def test_parser_using_yaml_store_complex_data(tmp_path: Path, runner: CLIRunner) -> None: + from argdantic import ArgParser + from argdantic.stores.yaml import YamlSettingsStore + + path = tmp_path / "settings.yaml" + parser = ArgParser() + + @parser.command(stores=[YamlSettingsStore(path, mode="json")]) + def main(foo: Path = "baz", bar: int = 42) -> None: + return str(foo), bar + + result = runner.invoke(parser, []) + assert result.exception is None + assert result.return_value == ("baz", 42) + + result = runner.invoke(parser, ["--foo", "qux", "--bar", "24"]) + assert result.exception is None + print(result.return_value) + assert result.return_value == ("qux", 24)