Skip to content

Commit

Permalink
Save MD postprocessing filenames
Browse files Browse the repository at this point in the history
  • Loading branch information
ElliottKasoar committed Feb 27, 2025
1 parent cd96197 commit ddd367a
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 36 deletions.
112 changes: 76 additions & 36 deletions janus_core/calculations/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,16 @@ def output_files(self) -> None:
if self.minimize_kwargs["write_results"]
else None
)
output_files["rdfs"] = (
self._rdf_files
if self.post_process_kwargs.get("rdf_compute", False)
else None
)
output_files["vafs"] = (
self._vaf_files
if self.post_process_kwargs.get("vaf_compute", False)
else None
)

return output_files

Expand Down Expand Up @@ -739,6 +749,70 @@ def _correlations_file(self) -> str:
"""
return self._build_filename("cor.dat", self.param_prefix)

@property
def _rdf_files(self) -> tuple[Path]:
"""
Get RDF filenames.
Returns
-------
str
Filenames for RDF file.
"""
base_name = self.post_process_kwargs.get("rdf_output_file", None)
rdf_args = {
name: self.post_process_kwargs.get(key, default)
for name, (key, default) in (
("elements", ("rdf_elements", None)),
("by_elements", ("rdf_by_elements", False)),
)
}

if rdf_args["by_elements"]:
elements = (
tuple(sorted(set(self.struct.get_chemical_symbols())))
if rdf_args["elements"] is None
else rdf_args["elements"]
)

out_paths = tuple(
self._build_filename(
"rdf.dat",
self.param_prefix,
"_".join(element),
prefix_override=base_name,
)
for element in combinations_with_replacement(elements, 2)
)

else:
out_paths = (
self._build_filename(
"rdf.dat", self.param_prefix, prefix_override=base_name
),
)

return out_paths

@property
def _vaf_files(self) -> str:
"""
Define VAF filenames.
Returns
-------
str
Filenames for VAF files.
"""
file_names = self.post_process_kwargs.get("vaf_output_files", None)
if not isinstance(file_names, Sequence):
file_names = (file_names,)

return tuple(
self._build_filename("vaf.dat", self.param_prefix, filename=file_name)
for file_name in file_names
)

def _parse_correlations(self) -> None:
"""Parse correlation kwargs into Correlations."""
if self.correlation_kwargs:
Expand Down Expand Up @@ -970,7 +1044,6 @@ def _post_process(self) -> None:
ana = Analysis(data)

if self.post_process_kwargs.get("rdf_compute", False):
base_name = self.post_process_kwargs.get("rdf_output_file", None)
rdf_args = {
name: self.post_process_kwargs.get(key, default)
for name, (key, default) in (
Expand All @@ -987,45 +1060,12 @@ def _post_process(self) -> None:
)
rdf_args["index"] = slice_

if rdf_args["by_elements"]:
elements = (
tuple(sorted(set(data[0].get_chemical_symbols())))
if rdf_args["elements"] is None
else rdf_args["elements"]
)

out_paths = [
self._build_filename(
"rdf.dat",
self.param_prefix,
"_".join(element),
prefix_override=base_name,
)
for element in combinations_with_replacement(elements, 2)
]

else:
out_paths = [
self._build_filename(
"rdf.dat", self.param_prefix, prefix_override=base_name
)
]

compute_rdf(data, ana, filenames=out_paths, **rdf_args)
compute_rdf(data, ana, filenames=self._rdf_files, **rdf_args)

if self.post_process_kwargs.get("vaf_compute", False):
file_names = self.post_process_kwargs.get("vaf_output_files", None)
use_vel = self.post_process_kwargs.get("vaf_velocities", False)
fft = self.post_process_kwargs.get("vaf_fft", False)

if not isinstance(file_names, Sequence):
file_names = (file_names,)

out_paths = tuple(
self._build_filename("vaf.dat", self.param_prefix, filename=file_name)
for file_name in file_names
)

slice_ = (
self.post_process_kwargs.get("vaf_start", 0),
self.post_process_kwargs.get("vaf_stop", None),
Expand All @@ -1034,7 +1074,7 @@ def _post_process(self) -> None:

compute_vaf(
data,
out_paths,
self._vaf_files,
use_velocities=use_vel,
fft=fft,
index=slice_,
Expand Down
4 changes: 4 additions & 0 deletions janus_core/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ def dict_paths_to_strs(dictionary: dict) -> None:
for key, value in dictionary.items():
if isinstance(value, dict):
dict_paths_to_strs(value)
elif isinstance(value, Sequence) and not isinstance(value, str):
dictionary[key] = [
str(path) if isinstance(path, Path) else path for path in value
]
elif isinstance(value, Path):
dictionary[key] = str(value)

Expand Down

0 comments on commit ddd367a

Please sign in to comment.