diff --git a/janus_core/calculations/neb.py b/janus_core/calculations/neb.py index b726e3cb..36b18598 100644 --- a/janus_core/calculations/neb.py +++ b/janus_core/calculations/neb.py @@ -2,7 +2,7 @@ from __future__ import annotations -from collections.abc import Callable +from collections.abc import Callable, Sequence from copy import copy from typing import Any @@ -12,6 +12,7 @@ from ase.mep import DyNEB, NEBTools from ase.mep.neb import NEBOptimizer from matplotlib.figure import Figure +from pymatgen.io.ase import AseAtomsAdaptor from janus_core.calculations.base import BaseCalculation from janus_core.calculations.geom_opt import GeomOpt @@ -19,6 +20,7 @@ Architectures, ASEReadArgs, Devices, + Interpolators, OutputKwargs, PathLike, ) @@ -45,6 +47,12 @@ class NEB(BaseCalculation): final_struct_path Path of final structure for Nudged Elastic Band method. Required if `final_struct` is None. Default is None. + band_structs + Band of ASE Atoms images to optimize, skipping interpolation between the initial + and final structures. Requires interpolator to be None. + band_path + Path of band of images to optimize, skipping interpolation between the initial + and final structures. Requires interpolator to be None. arch MLIP architecture to use for Nudged Elastic Band method. Default is "mace_mp". device @@ -83,6 +91,8 @@ class NEB(BaseCalculation): {"k": 0.1, "climb": True, "method": "string"} for NEB, {"fmax": 0.1, "dynamic_relaxation": True, "climb": True, "scale_fmax": 1.2} for DynNEB, else {}. + interpolator + Choice of interpolation strategy. Default is "ase". interpolation_kwargs Keyword arguments to pass to neb_method.interpolate. Default is {"method": "idpp"}. @@ -115,6 +125,8 @@ def __init__( init_struct_path: PathLike | None = None, final_struct: Atoms | None = None, final_struct_path: PathLike | None = None, + band_structs: Sequence[Atoms] | None = None, + band_path: PathLike | None = None, arch: Architectures = "mace_mp", device: Devices = "cpu", model_path: PathLike | None = None, @@ -131,6 +143,7 @@ def __init__( write_results: bool = False, write_kwargs: OutputKwargs | None = None, neb_kwargs: dict[str, Any] | None = None, + interpolator: Interpolators | None = "ase", interpolation_kwargs: dict[str, Any] | None = None, neb_optimizer: Callable | str = NEBOptimizer, fmax: float = 0.1, @@ -159,6 +172,12 @@ def __init__( final_struct_path Path of final structure for Nudged Elastic Band method. Required if `final_struct` is None. Default is None. + band_structs + Band of ASE Atoms images to optimize, skipping interpolation between the + initial and final structures. Requires interpolator to be None. + band_path + Path of band of images to optimize, skipping interpolation between the + initial and final structures. Requires interpolator to be None. arch MLIP architecture to use for Nudged Elastic Band method. Default is "mace_mp". @@ -198,6 +217,8 @@ def __init__( {"k": 0.1, "climb": True, "method": "string"} for NEB, {"fmax": 0.1, "dynamic_relaxation": True, "climb": True, "scale_fmax": 1.2} for DynNEB, else {}. + interpolator + Choice of interpolation strategy. Default is "ase". interpolation_kwargs Keyword arguments to pass to neb_method.interpolate. Default is {"method": "idpp"}. @@ -240,6 +261,7 @@ def __init__( self.write_results = write_results self.write_kwargs = write_kwargs self.neb_kwargs = neb_kwargs + self.interpolator = interpolator self.interpolation_kwargs = interpolation_kwargs self.neb_optimizer = neb_optimizer self.fmax = fmax @@ -252,8 +274,22 @@ def __init__( if self.n_images <= 0 or not isinstance(self.n_images, int): raise ValueError("`n_images` must be an integer greater than 0.") - # Read last image by default - read_kwargs.setdefault("index", -1) + if band_structs or band_path: + if init_struct or init_struct_path or final_struct or final_struct_path: + raise ValueError( + "Band cannot be specified in combination with an initial or final " + "structure" + ) + if interpolator is not None: + raise ValueError("An interpolator cannot when specifying the band") + + init_struct = band_structs + init_struct_path = band_path + # Read all image by default for band + read_kwargs.setdefault("index", ":") + else: + # Read last image by default for init_struct + read_kwargs.setdefault("index", -1) # Initialise structures and logging super().__init__( @@ -264,7 +300,7 @@ def __init__( device=device, model_path=model_path, read_kwargs=read_kwargs, - sequence_allowed=False, + sequence_allowed=True, calc_kwargs=calc_kwargs, set_calc=set_calc, attach_logger=attach_logger, @@ -274,39 +310,43 @@ def __init__( file_prefix=file_prefix, ) - if not self.struct.calc: - raise ValueError("Please attach a calculator to `init_struct`.") - - # Use initial structure (path) for default file paths etc. - self.init_struct = self.struct - self.init_struct_path = self.struct_path - - self.final_struct = input_structs( - struct=final_struct, - struct_path=final_struct_path, - read_kwargs=read_kwargs, - sequence_allowed=False, - set_calc=False, - ) - self.final_struct_path = final_struct_path - self.final_struct.calc = copy(self.struct.calc) + if interpolator is not None: + if not isinstance(self.struct, Atoms): + raise ValueError("`init_struct` must be a single structure.") + if not self.struct.calc: + raise ValueError("Please attach a calculator to `init_struct`.") + + # Use initial structure (path) for default file paths etc. + self.init_struct = self.struct + self.init_struct_path = self.struct_path + + self.final_struct = input_structs( + struct=final_struct, + struct_path=final_struct_path, + read_kwargs=read_kwargs, + sequence_allowed=False, + set_calc=False, + ) + self.final_struct_path = final_struct_path + self.final_struct.calc = copy(self.struct.calc) + else: + if not isinstance(self.struct, Sequence): + raise ValueError("`images` must include multiple structures.") - # Set default interpolation method - interpolation_kwargs.setdefault("method", "idpp") + # Set default interpolation kwargs + if self.interpolator == "ase": + interpolation_kwargs.setdefault("method", "idpp") + if self.interpolator == "pymatgen": + interpolation_kwargs.setdefault("interpolate_lattices", False) + interpolation_kwargs.setdefault("autosort_tol", 0.5) # Set output file defaults self.plot_file = self._build_filename("neb-plot.svg").absolute() - self.images_write_kwargs = self.write_kwargs.copy() - self.images_write_kwargs["filename"] = self._build_filename( + self.write_kwargs["filename"] = self._build_filename( "neb-images.extxyz" ).absolute() - self.results_write_kwargs = self.write_kwargs.copy() - self.results_write_kwargs["filename"] = self._build_filename( - "neb-opt_images.extxyz" - ).absolute() - if self.minimize: set_minimize_logging( self.logger, self.minimize_kwargs, self.log_kwargs, track_carbon @@ -390,6 +430,41 @@ def plot(self) -> Figure | None: return fig + def set_interpolator(self) -> None: + """Interpolate images to create initial band.""" + match self.interpolator: + case "ase": + # Create band of images and attach calculators + self.images = [self.init_struct] + self.images += [self.init_struct.copy() for i in range(self.n_images)] + for image in self.images[1:]: + image.calc = copy(self.init_struct.calc) + self.images += [self.final_struct] + + self.neb = self.neb_method(self.images, **self.neb_kwargs) + self.neb.interpolate(**self.interpolation_kwargs) + + case "pymatgen": + # Create band of images and attach calculators + py_start_struct = AseAtomsAdaptor.get_structure(self.init_struct) + py_final_struct = AseAtomsAdaptor.get_structure(self.final_struct) + py_images = py_start_struct.interpolate( + py_final_struct, + nimages=self.n_images + 1, + **self.interpolation_kwargs, + ) + self.images = [image.to_ase_atoms() for image in py_images] + for image in self.images: + image.calc = copy(self.init_struct.calc) + + self.neb = self.neb_method(self.images, **self.neb_kwargs) + + case None: + # Band already created + pass + case _: + raise ValueError("Invalid interpolator selected") + def run(self) -> dict[str, float]: """ Run Nudged Elastic Band method. @@ -415,27 +490,19 @@ def run(self) -> dict[str, float]: ) GeomOpt(self.final_struct, **self.minimize_kwargs).run() - # Create band of images and attach calculators - self.images = [self.init_struct] - self.images += [self.init_struct.copy() for i in range(self.n_images)] - for image in self.images[1:]: - image.calc = copy(self.init_struct.calc) - self.images += [self.final_struct] + self.set_interpolator() + + optimizer = self.neb_optimizer(self.neb, **self.optimizer_kwargs) + optimizer.run(fmax=self.fmax) # Optionally write band images to file output_structs( images=self.images, struct_path=self.struct_path, write_results=self.write_images, - write_kwargs=self.images_write_kwargs, + write_kwargs=self.write_kwargs, ) - self.neb = self.neb_method(self.images, **self.neb_kwargs) - self.neb.interpolate(**self.interpolation_kwargs) - - optimizer = self.neb_optimizer(self.neb, **self.optimizer_kwargs) - optimizer.run(fmax=self.fmax) - self.nebtools = NEBTools(self.images[1:-1]) barrier, delta_E = self.nebtools.get_barrier() # noqa: N806 max_force = self.nebtools.get_fmax() @@ -445,14 +512,6 @@ def run(self) -> dict[str, float]: "max_force": max_force, } - # Optionally write out optimized images - output_structs( - self.images, - struct_path=self.struct_path, - write_results=self.write_results, - write_kwargs=self.results_write_kwargs, - ) - self.plot() if self.logger: diff --git a/janus_core/helpers/janus_types.py b/janus_core/helpers/janus_types.py index f31a4b8b..40c8853c 100644 --- a/janus_core/helpers/janus_types.py +++ b/janus_core/helpers/janus_types.py @@ -121,6 +121,7 @@ class CorrelationKwargs(TypedDict, total=True): Ensembles = Literal["nph", "npt", "nve", "nvt", "nvt-nh", "nvt-csvr"] Properties = Literal["energy", "stress", "forces", "hessian"] PhononCalcs = Literal["bands", "dos", "pdos", "thermal"] +Interpolators = Literal["ase", "pymatgen"] class OutputKwargs(ASEWriteArgs, total=False): diff --git a/pyproject.toml b/pyproject.toml index 23803491..9c8266d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "mace-torch==0.3.9", "numpy<2.0.0,>=1.26.4", "phonopy<3.0.0,>=2.23.1", + "pymatgen>=2025.1.24", "pyyaml<7.0.0,>=6.0.1", "rich<14.0.0,>=13.9.1", "seekpath<2.0.0,>=1.9.7",