Skip to content

Commit

Permalink
Add interpolation options
Browse files Browse the repository at this point in the history
  • Loading branch information
ElliottKasoar committed Jan 31, 2025
1 parent 540ab96 commit 2f8aede
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 50 deletions.
159 changes: 109 additions & 50 deletions janus_core/calculations/neb.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from collections.abc import Callable
from collections.abc import Callable, Sequence
from copy import copy
from typing import Any

Expand All @@ -12,13 +12,15 @@
from ase.mep import DyNEB, NEBTools
from ase.mep.neb import NEBOptimizer
from matplotlib.figure import Figure
from pymatgen.io.ase import AseAtomsAdaptor

from janus_core.calculations.base import BaseCalculation
from janus_core.calculations.geom_opt import GeomOpt
from janus_core.helpers.janus_types import (
Architectures,
ASEReadArgs,
Devices,
Interpolators,
OutputKwargs,
PathLike,
)
Expand All @@ -45,6 +47,12 @@ class NEB(BaseCalculation):
final_struct_path
Path of final structure for Nudged Elastic Band method. Required if
`final_struct` is None. Default is None.
band_structs
Band of ASE Atoms images to optimize, skipping interpolation between the initial
and final structures. Requires interpolator to be None.
band_path
Path of band of images to optimize, skipping interpolation between the initial
and final structures. Requires interpolator to be None.
arch
MLIP architecture to use for Nudged Elastic Band method. Default is "mace_mp".
device
Expand Down Expand Up @@ -83,6 +91,8 @@ class NEB(BaseCalculation):
{"k": 0.1, "climb": True, "method": "string"} for NEB,
{"fmax": 0.1, "dynamic_relaxation": True, "climb": True, "scale_fmax": 1.2} for
DynNEB, else {}.
interpolator
Choice of interpolation strategy. Default is "ase".
interpolation_kwargs
Keyword arguments to pass to neb_method.interpolate. Default is
{"method": "idpp"}.
Expand Down Expand Up @@ -115,6 +125,8 @@ def __init__(
init_struct_path: PathLike | None = None,
final_struct: Atoms | None = None,
final_struct_path: PathLike | None = None,
band_structs: Sequence[Atoms] | None = None,
band_path: PathLike | None = None,
arch: Architectures = "mace_mp",
device: Devices = "cpu",
model_path: PathLike | None = None,
Expand All @@ -131,6 +143,7 @@ def __init__(
write_results: bool = False,
write_kwargs: OutputKwargs | None = None,
neb_kwargs: dict[str, Any] | None = None,
interpolator: Interpolators | None = "ase",
interpolation_kwargs: dict[str, Any] | None = None,
neb_optimizer: Callable | str = NEBOptimizer,
fmax: float = 0.1,
Expand Down Expand Up @@ -159,6 +172,12 @@ def __init__(
final_struct_path
Path of final structure for Nudged Elastic Band method. Required if
`final_struct` is None. Default is None.
band_structs
Band of ASE Atoms images to optimize, skipping interpolation between the
initial and final structures. Requires interpolator to be None.
band_path
Path of band of images to optimize, skipping interpolation between the
initial and final structures. Requires interpolator to be None.
arch
MLIP architecture to use for Nudged Elastic Band method. Default is
"mace_mp".
Expand Down Expand Up @@ -198,6 +217,8 @@ def __init__(
{"k": 0.1, "climb": True, "method": "string"} for NEB,
{"fmax": 0.1, "dynamic_relaxation": True, "climb": True, "scale_fmax": 1.2}
for DynNEB, else {}.
interpolator
Choice of interpolation strategy. Default is "ase".
interpolation_kwargs
Keyword arguments to pass to neb_method.interpolate. Default is
{"method": "idpp"}.
Expand Down Expand Up @@ -240,6 +261,7 @@ def __init__(
self.write_results = write_results
self.write_kwargs = write_kwargs
self.neb_kwargs = neb_kwargs
self.interpolator = interpolator
self.interpolation_kwargs = interpolation_kwargs
self.neb_optimizer = neb_optimizer
self.fmax = fmax
Expand All @@ -252,8 +274,22 @@ def __init__(
if self.n_images <= 0 or not isinstance(self.n_images, int):
raise ValueError("`n_images` must be an integer greater than 0.")

# Read last image by default
read_kwargs.setdefault("index", -1)
if band_structs or band_path:
if init_struct or init_struct_path or final_struct or final_struct_path:
raise ValueError(
"Band cannot be specified in combination with an initial or final "
"structure"
)
if interpolator is not None:
raise ValueError("An interpolator cannot when specifying the band")

init_struct = band_structs
init_struct_path = band_path
# Read all image by default for band
read_kwargs.setdefault("index", ":")
else:
# Read last image by default for init_struct
read_kwargs.setdefault("index", -1)

# Initialise structures and logging
super().__init__(
Expand All @@ -264,7 +300,7 @@ def __init__(
device=device,
model_path=model_path,
read_kwargs=read_kwargs,
sequence_allowed=False,
sequence_allowed=True,
calc_kwargs=calc_kwargs,
set_calc=set_calc,
attach_logger=attach_logger,
Expand All @@ -274,39 +310,43 @@ def __init__(
file_prefix=file_prefix,
)

if not self.struct.calc:
raise ValueError("Please attach a calculator to `init_struct`.")

# Use initial structure (path) for default file paths etc.
self.init_struct = self.struct
self.init_struct_path = self.struct_path

self.final_struct = input_structs(
struct=final_struct,
struct_path=final_struct_path,
read_kwargs=read_kwargs,
sequence_allowed=False,
set_calc=False,
)
self.final_struct_path = final_struct_path
self.final_struct.calc = copy(self.struct.calc)
if interpolator is not None:
if not isinstance(self.struct, Atoms):
raise ValueError("`init_struct` must be a single structure.")
if not self.struct.calc:
raise ValueError("Please attach a calculator to `init_struct`.")

# Use initial structure (path) for default file paths etc.
self.init_struct = self.struct
self.init_struct_path = self.struct_path

self.final_struct = input_structs(
struct=final_struct,
struct_path=final_struct_path,
read_kwargs=read_kwargs,
sequence_allowed=False,
set_calc=False,
)
self.final_struct_path = final_struct_path
self.final_struct.calc = copy(self.struct.calc)
else:
if not isinstance(self.struct, Sequence):
raise ValueError("`images` must include multiple structures.")

# Set default interpolation method
interpolation_kwargs.setdefault("method", "idpp")
# Set default interpolation kwargs
if self.interpolator == "ase":
interpolation_kwargs.setdefault("method", "idpp")
if self.interpolator == "pymatgen":
interpolation_kwargs.setdefault("interpolate_lattices", False)
interpolation_kwargs.setdefault("autosort_tol", 0.5)

# Set output file defaults
self.plot_file = self._build_filename("neb-plot.svg").absolute()

self.images_write_kwargs = self.write_kwargs.copy()
self.images_write_kwargs["filename"] = self._build_filename(
self.write_kwargs["filename"] = self._build_filename(
"neb-images.extxyz"
).absolute()

self.results_write_kwargs = self.write_kwargs.copy()
self.results_write_kwargs["filename"] = self._build_filename(
"neb-opt_images.extxyz"
).absolute()

if self.minimize:
set_minimize_logging(
self.logger, self.minimize_kwargs, self.log_kwargs, track_carbon
Expand Down Expand Up @@ -390,6 +430,41 @@ def plot(self) -> Figure | None:

return fig

def set_interpolator(self) -> None:
"""Interpolate images to create initial band."""
match self.interpolator:
case "ase":
# Create band of images and attach calculators
self.images = [self.init_struct]
self.images += [self.init_struct.copy() for i in range(self.n_images)]
for image in self.images[1:]:
image.calc = copy(self.init_struct.calc)
self.images += [self.final_struct]

self.neb = self.neb_method(self.images, **self.neb_kwargs)
self.neb.interpolate(**self.interpolation_kwargs)

case "pymatgen":
# Create band of images and attach calculators
py_start_struct = AseAtomsAdaptor.get_structure(self.init_struct)
py_final_struct = AseAtomsAdaptor.get_structure(self.final_struct)
py_images = py_start_struct.interpolate(
py_final_struct,
nimages=self.n_images + 1,
**self.interpolation_kwargs,
)
self.images = [image.to_ase_atoms() for image in py_images]
for image in self.images:
image.calc = copy(self.init_struct.calc)

self.neb = self.neb_method(self.images, **self.neb_kwargs)

case None:
# Band already created
pass
case _:
raise ValueError("Invalid interpolator selected")

def run(self) -> dict[str, float]:
"""
Run Nudged Elastic Band method.
Expand All @@ -415,27 +490,19 @@ def run(self) -> dict[str, float]:
)
GeomOpt(self.final_struct, **self.minimize_kwargs).run()

# Create band of images and attach calculators
self.images = [self.init_struct]
self.images += [self.init_struct.copy() for i in range(self.n_images)]
for image in self.images[1:]:
image.calc = copy(self.init_struct.calc)
self.images += [self.final_struct]
self.set_interpolator()

optimizer = self.neb_optimizer(self.neb, **self.optimizer_kwargs)
optimizer.run(fmax=self.fmax)

# Optionally write band images to file
output_structs(
images=self.images,
struct_path=self.struct_path,
write_results=self.write_images,
write_kwargs=self.images_write_kwargs,
write_kwargs=self.write_kwargs,
)

self.neb = self.neb_method(self.images, **self.neb_kwargs)
self.neb.interpolate(**self.interpolation_kwargs)

optimizer = self.neb_optimizer(self.neb, **self.optimizer_kwargs)
optimizer.run(fmax=self.fmax)

self.nebtools = NEBTools(self.images[1:-1])
barrier, delta_E = self.nebtools.get_barrier() # noqa: N806
max_force = self.nebtools.get_fmax()
Expand All @@ -445,14 +512,6 @@ def run(self) -> dict[str, float]:
"max_force": max_force,
}

# Optionally write out optimized images
output_structs(
self.images,
struct_path=self.struct_path,
write_results=self.write_results,
write_kwargs=self.results_write_kwargs,
)

self.plot()

if self.logger:
Expand Down
1 change: 1 addition & 0 deletions janus_core/helpers/janus_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ class CorrelationKwargs(TypedDict, total=True):
Ensembles = Literal["nph", "npt", "nve", "nvt", "nvt-nh", "nvt-csvr"]
Properties = Literal["energy", "stress", "forces", "hessian"]
PhononCalcs = Literal["bands", "dos", "pdos", "thermal"]
Interpolators = Literal["ase", "pymatgen"]


class OutputKwargs(ASEWriteArgs, total=False):
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ dependencies = [
"mace-torch==0.3.9",
"numpy<2.0.0,>=1.26.4",
"phonopy<3.0.0,>=2.23.1",
"pymatgen>=2025.1.24",
"pyyaml<7.0.0,>=6.0.1",
"rich<14.0.0,>=13.9.1",
"seekpath<2.0.0,>=1.9.7",
Expand Down

0 comments on commit 2f8aede

Please sign in to comment.