Skip to content

Commit

Permalink
PhononWorkChain: add max concurrent running pw workchains
Browse files Browse the repository at this point in the history
Fixes #52

Add the `settings.max_concurrent_base_workchains` input to the `PhononWorkChain`
to run only up to a maximum number of PwBaseWorkChain at the same time. This is
particularly useful for low symmetry materials where many displacements are generated.
This avoids to overload both the AiiDA daemon and especially the HPC. Moreover, this
even allows for running on the local machine (where AiiDA is installed) this workflow.
  • Loading branch information
bastonero committed Dec 22, 2024
1 parent 835266a commit d9cc29f
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 8 deletions.
32 changes: 27 additions & 5 deletions src/aiida_vibroscopy/workflows/phonons/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from aiida import orm
from aiida.common.extendeddicts import AttributeDict
from aiida.common.lang import type_check
from aiida.engine import WorkChain, calcfunction, if_
from aiida.engine import WorkChain, calcfunction, if_, while_
from aiida.plugins import CalculationFactory, DataFactory, WorkflowFactory
from aiida_quantumespresso.calculations.functions.create_kpoints_from_distance import create_kpoints_from_distance
from aiida_quantumespresso.workflows.protocols.utils import ProtocolMixin
Expand Down Expand Up @@ -133,6 +133,10 @@ def define(cls, spec):
'settings.sleep_submission_time', valid_type=(int, float), non_db=True, default=3.0,
help='Time in seconds to wait before submitting subsequent displaced structure scf calculations.',
)
spec.input(
'settings.max_concurrent_base_workchains', valid_type=int, non_db=True, default=20,
help='Maximum number of concurrent running `PwBaseWorkChain`.'
)
spec.input(
'clean_workdir', valid_type=orm.Bool, default=lambda: orm.Bool(False),
help='If `True`, work directories of all called calculation will be cleaned at the end of execution.'
Expand All @@ -143,7 +147,10 @@ def define(cls, spec):
cls.set_reference_kpoints,
cls.run_base_supercell,
cls.inspect_base_supercell,
cls.run_forces,
cls.run_supercells,
while_(cls.should_run_forces)(
cls.run_forces,
),
cls.inspect_all_runs,
cls.set_phonopy_data,
if_(cls.should_run_phonopy)(
Expand Down Expand Up @@ -379,8 +386,8 @@ def inspect_base_supercell(self):
fermi_energy = parameters.fermi_energy
self.ctx.is_insulator, _ = orm.find_bandgap(bands, fermi_energy=fermi_energy)

def run_forces(self):
"""Run an scf for each supercell with displacements."""
def run_supercells(self):
"""Run supercell with displacements."""
if self.ctx.plus_hubbard or self.ctx.old_plus_hubbard:
supercells = get_supercells_for_hubbard(
preprocess_data=self.ctx.preprocess_data,
Expand All @@ -389,12 +396,27 @@ def run_forces(self):
else:
supercells = self.ctx.preprocess_data.calcfunctions.get_supercells_with_displacements()

self.ctx.supercells = []
for key, value in supercells.items():
self.ctx.supercells.append((key, value))

self.out('supercells', supercells)

def should_run_forces(self):
"""Whether to run or not forces."""
return len(self.ctx.supercells) > 0

def run_forces(self):
"""Run an scf for each supercell with displacements."""
base_key = f'{self._RUN_PREFIX}_0'
base_out = self.ctx[base_key].outputs

for key, supercell in supercells.items():
n_base_parallel = self.inputs.settings.max_concurrent_base_workchains
if self.inputs.settings.max_concurrent_base_workchains < 0:
n_base_parallel = len(self.ctx.supercells)

Check warning on line 416 in src/aiida_vibroscopy/workflows/phonons/base.py

View check run for this annotation

Codecov / codecov/patch

src/aiida_vibroscopy/workflows/phonons/base.py#L416

Added line #L416 was not covered by tests

for _ in self.ctx.supercells[:n_base_parallel]:
key, supercell = self.ctx.supercells.pop(0)
num = key.split('_')[-1]
label = f'{self._RUN_PREFIX}_{num}'

Expand Down
39 changes: 36 additions & 3 deletions tests/workflows/phonons/test_phonon.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#################################################################################
"""Tests for the :mod:`workflows.phonons.phonon` module."""
from aiida import orm
from aiida.common import AttributeDict
from aiida_quantumespresso.data.hubbard_structure import HubbardStructureData
import pytest

Expand All @@ -29,13 +30,13 @@ def _generate_workchain_phonon(structure=None, append_inputs=None, return_inputs
inputs = {
'scf': scf_inputs,
'settings': {
'sleep_submission_time': 0.
'sleep_submission_time': 0.,
},
'symmetry': {},
}

if return_inputs:
return inputs
return AttributeDict(inputs)

if append_inputs is not None:
inputs.update(append_inputs)
Expand Down Expand Up @@ -223,23 +224,53 @@ def test_inspect_base_supercell(
assert process.ctx.is_insulator


@pytest.mark.usefixtures('aiida_profile')
def test_run_supercells(generate_workchain_phonon):
"""Test `PhononWorkChain.run_supercells` method."""
process = generate_workchain_phonon()
process.setup()
process.run_supercells()

assert 'supercells' in process.outputs
assert 'supercells' in process.ctx
assert 'supercell_1' in process.outputs['supercells']


@pytest.mark.usefixtures('aiida_profile')
def test_should_run_forces(generate_workchain_phonon):
"""Test `PhononWorkChain.should_run_forces` method."""
process = generate_workchain_phonon()
process.setup()
process.run_supercells()
assert process.should_run_forces()


@pytest.mark.usefixtures('aiida_profile')
def test_run_forces(generate_workchain_phonon, generate_base_scf_workchain_node):
"""Test `PhononWorkChain.run_forces` method."""
process = generate_workchain_phonon()
append_inputs = {
'settings': {
'sleep_submission_time': 0.,
'max_concurrent_base_workchains': 1,
}
}
process = generate_workchain_phonon(append_inputs=append_inputs)

process.setup()
process.set_reference_kpoints()
process.run_base_supercell()
process.run_supercells()

assert 'scf_supercell_0'

num_supercells = len(process.ctx.supercells)
process.ctx.scf_supercell_0 = generate_base_scf_workchain_node()
process.run_forces()

assert 'supercells' in process.outputs
assert 'supercell_1' in process.outputs['supercells']
assert 'scf_supercell_1' in process.ctx
assert num_supercells == len(process.ctx.supercells) + 1


@pytest.mark.usefixtures('aiida_profile')
Expand All @@ -251,6 +282,7 @@ def test_run_forces_with_hubbard(generate_workchain_phonon, generate_base_scf_wo
process.setup()
process.set_reference_kpoints()
process.run_base_supercell()
process.run_supercells()

assert 'scf_supercell_0'

Expand All @@ -261,6 +293,7 @@ def test_run_forces_with_hubbard(generate_workchain_phonon, generate_base_scf_wo
assert 'supercell_1' in process.outputs['supercells']
assert isinstance(process.outputs['supercells']['supercell_1'], HubbardStructureData)
assert 'scf_supercell_1' in process.ctx
assert len(process.ctx.supercells) == 0


@pytest.mark.parametrize(('expected_result', 'exit_status'),
Expand Down
4 changes: 4 additions & 0 deletions tests/workflows/protocols/test_phonon.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,15 @@ def test_overrides(fixture_code, generate_structure):
'displacement_generator': {
'distance': 0.005
},
'settings': {
'max_concurrent_base_workchains': 1,
}
}
builder = PhononWorkChain.get_builder_from_protocol(code, structure, overrides=overrides)

assert builder.primitive_matrix.get_list() == [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
assert builder.displacement_generator.get_dict() == {'distance': 0.005}
assert builder.settings.max_concurrent_base_workchains == 1


def test_phonon_properties(fixture_code, generate_structure):
Expand Down

0 comments on commit d9cc29f

Please sign in to comment.