Merge pull request #253 from RaulPPelaez/check_errors
Add check_errors option
RaulPPelaez authored Jan 18, 2024
2 parents b1f2f5a + 2860772 commit 93d3d8b
Showing 8 changed files with 62 additions and 11 deletions.
11 changes: 11 additions & 0 deletions docs/source/usage.rst
@@ -62,6 +62,16 @@ These requirements correspond to a particular rotation of the system and reduced

.. note:: The box defined by the vectors :math:`\vec{a} = (L_x, 0, 0)`, :math:`\vec{b} = (0, L_y, 0)`, and :math:`\vec{c} = (0, 0, L_z)` corresponds to a rectangular box. In this case, the input option in the :ref:`configuration file <configuration-file>` would be ``box-vecs: [[L_x, 0, 0], [0, L_y, 0], [0, 0, L_z]]``.


CUDA Graphs
============

With the right options, TensorNet can be captured into a `CUDA graph <https://developer.nvidia.com/blog/cuda-graphs/>`_, which can dramatically increase inference performance. Because training typically involves dynamically shaped tensors, CUDA graphs are not an option in most practical training scenarios.

For TensorNet to be CUDA-graph compatible, `check_errors` must be `False` and `static_shapes` must be `True`. Manually capturing a piece of code can be challenging; instead, you can take advantage of CUDA graphs through :py:mod:`torchmdnet.calculators.External`, which helps integrate a TorchMD-Net model into another code, or through `OpenMM-Torch <https://github.com/openmm/openmm-torch>`_ if you are using OpenMM.
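As a minimal sketch (the checkpoint name is a placeholder), both options can be supplied as keyword overrides when loading a trained model, since `load_model` forwards extra keyword arguments into the stored hyperparameters:

from torchmdnet.models.model import load_model

# "model.ckpt" is a placeholder path; check_errors and static_shapes override the
# hyperparameters stored in the checkpoint because load_model forwards extra
# keyword arguments into the model arguments before calling create_model.
model = load_model("model.ckpt", device="cuda", check_errors=False, static_shapes=True)

The resulting model is then eligible for capture, for example through the helpers mentioned above.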



Multi-Node Training
===================

@@ -85,6 +95,7 @@ In order to train models on multiple nodes some environment variables have to be
- Due to the way PyTorch Lightning calculates the number of required DDP processes, all nodes must use the same number of GPUs. Otherwise, training will either not start or will crash.
- We observe a 50x decrease in performance when mixing nodes with different GPU architectures (tested with RTX 2080 Ti and RTX 3090).


Developer Guide
---------------

2 changes: 2 additions & 0 deletions tests/test_model.py
@@ -107,6 +107,8 @@ def test_cuda_graph_compatible(model_name):
"prior_model": None,
"atom_filter": -1,
"derivative": True,
"check_error": False,
"static_shapes": True,
"output_model": "Scalar",
"reduce_op": "sum",
"precision": 32 }
36 changes: 28 additions & 8 deletions torchmdnet/models/model.py
@@ -32,6 +32,10 @@ def create_model(args, prior_model=None, mean=None, std=None):
dtype = dtype_mapping[args["precision"]]
if "box_vecs" not in args:
args["box_vecs"] = None
if "check_errors" not in args:
args["check_errors"] = True
if "static_shapes" not in args:
args["static_shapes"] = False
shared_args = dict(
hidden_channels=args["embedding_dimension"],
num_layers=args["num_layers"],
@@ -42,8 +46,11 @@ def create_model(args, prior_model=None, mean=None, std=None):
cutoff_lower=args["cutoff_lower"],
cutoff_upper=args["cutoff_upper"],
max_z=args["max_z"],
check_errors=args["check_errors"],
max_num_neighbors=args["max_num_neighbors"],
box_vecs=torch.tensor(args["box_vecs"], dtype=dtype) if args["box_vecs"] is not None else None,
box_vecs=torch.tensor(args["box_vecs"], dtype=dtype)
if args["box_vecs"] is not None
else None,
dtype=dtype,
)

@@ -87,6 +94,7 @@ def create_model(args, prior_model=None, mean=None, std=None):
is_equivariant = False
representation_model = TensorNet(
equivariance_invariance_group=args["equivariance_invariance_group"],
static_shapes=args["static_shapes"],
**shared_args,
)
else:
@@ -153,11 +161,21 @@ def load_model(filepath, args=None, device="cpu", **kwargs):
# The following are for backward compatibility with models created when atomref was
# the only supported prior.
if "prior_model.initial_atomref" in state_dict:
warnings.warn(
"prior_model.initial_atomref is deprecated and will be removed in a future version. Use prior_model.0.initial_atomref instead.",
category=DeprecationWarning,
stacklevel=2,
)
state_dict["prior_model.0.initial_atomref"] = state_dict[
"prior_model.initial_atomref"
]
del state_dict["prior_model.initial_atomref"]
if "prior_model.atomref.weight" in state_dict:
warnings.warn(
"prior_model.atomref.weight is deprecated and will be removed in a future version. Use prior_model.0.atomref.weight instead.",
category=DeprecationWarning,
stacklevel=2,
)
state_dict["prior_model.0.atomref.weight"] = state_dict[
"prior_model.atomref.weight"
]
@@ -201,7 +219,7 @@ def create_prior_models(args, dataset=None):


class TorchMD_Net(nn.Module):
""" The main TorchMD-Net model.
"""The main TorchMD-Net model.
The TorchMD_Net class combines a given representation model (such as the equivariant transformer),
an output model (such as the scalar output module), and a prior model (such as the atomref prior).
@@ -311,15 +329,15 @@ def forward(
Args:
z (Tensor): Atomic numbers of the atoms in the molecule. Shape: (N,).
pos (Tensor): Atomic positions in the molecule. Shape: (N, 3).
batch (Tensor, optional): Batch indices for the atoms in the molecule. Shape: (N,).
box (Tensor, optional): Box vectors. Shape (3, 3).
The vectors defining the periodic box. This must have shape `(3, 3)`,
where `box_vectors[0] = a`, `box_vectors[1] = b`, and `box_vectors[2] = c`.
If this is omitted, periodic boundary conditions are not applied.
q (Tensor, optional): Atomic charges in the molecule. Shape: (N,).
s (Tensor, optional): Atomic spins in the molecule. Shape: (N,).
extra_args (Dict[str, Tensor], optional): Extra arguments to pass to the prior model.
Returns:
Tuple[Tensor, Optional[Tensor]]: The output of the model and the derivative of the output with respect to the positions if derivative is True, None otherwise.
@@ -330,7 +348,9 @@ def forward(
if self.derivative:
pos.requires_grad_(True)
# run the potentially wrapped representation model
x, v, z, pos, batch = self.representation_model(z, pos, batch, box=box, q=q, s=s)
x, v, z, pos, batch = self.representation_model(
z, pos, batch, box=box, q=q, s=s
)
# apply the output network
x = self.output_model.pre_reduce(x, v, z, pos, batch)

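A call following the `forward` signature documented above might look like this (the model instance is assumed to exist, e.g. as returned by `load_model`):

import torch

z = torch.tensor([8, 1, 1], dtype=torch.long)       # atomic numbers, shape (N,)
pos = torch.tensor([[0.00, 0.00, 0.00],
                    [0.96, 0.00, 0.00],
                    [-0.24, 0.93, 0.00]])            # positions, shape (N, 3)
batch = torch.zeros(3, dtype=torch.long)             # all atoms belong to molecule 0

# `model` is assumed to be a TorchMD_Net instance; box=..., q=... and s=... are
# optional arguments per the docstring. The second return value is the derivative
# w.r.t. positions when derivative=True, and None otherwise.
y, neg_dy = model(z, pos, batch=batch)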
7 changes: 5 additions & 2 deletions torchmdnet/models/tensornet.py
@@ -116,7 +116,9 @@ class TensorNet(nn.Module):
If this is omitted, periodic boundary conditions are not applied.
(default: :obj:`None`)
static_shapes (bool, optional): Whether to enforce static shapes.
Makes the model CUDA-graph compatible.
Makes the model CUDA-graph compatible if check_errors is set to False.
(default: :obj:`True`)
check_errors (bool, optional): Whether to check for errors in the distance module.
(default: :obj:`True`)
"""

@@ -134,6 +136,7 @@ def __init__(
max_z=128,
equivariance_invariance_group="O(3)",
static_shapes=True,
check_errors=True,
dtype=torch.float32,
box_vecs=None,
):
@@ -202,7 +205,7 @@ def __init__(
max_num_pairs=-max_num_neighbors,
return_vecs=True,
loop=True,
check_errors=False,
check_errors=check_errors,
resize_to_fit=not self.static_shapes,
box=box_vecs,
long_edge_index=True,
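For reference, the flag can also be set when constructing the representation module directly (a sketch; all other constructor arguments keep their defaults):

from torchmdnet.models.tensornet import TensorNet

# check_errors is now forwarded to the internal distance module rather than being
# hard-coded to False; turning it off keeps the module CUDA-graph capturable.
net = TensorNet(static_shapes=True, check_errors=False)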
4 changes: 4 additions & 0 deletions torchmdnet/models/torchmd_et.py
@@ -75,6 +75,8 @@ class TorchMD_ET(nn.Module):
where `box_vectors[0] = a`, `box_vectors[1] = b`, and `box_vectors[2] = c`.
If this is omitted, periodic boundary conditions are not applied.
(default: :obj:`None`)
check_errors (bool, optional): Whether to check for errors in the distance module.
(default: :obj:`True`)
"""

@@ -94,6 +96,7 @@ def __init__(
cutoff_upper=5.0,
max_z=100,
max_num_neighbors=32,
check_errors=True,
box_vecs=None,
dtype=torch.float32,
):
@@ -140,6 +143,7 @@ def __init__(
loop=True,
box=box_vecs,
long_edge_index=True,
check_errors=check_errors,
)
self.distance_expansion = rbf_class_mapping[rbf_type](
cutoff_lower, cutoff_upper, num_rbf, trainable_rbf
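The same keyword is plumbed through TorchMD_ET (and, in the files below, TorchMD_GN and TorchMD_T); as a sketch of the trade-off, assuming default values for the remaining arguments:

from torchmdnet.models.torchmd_et import TorchMD_ET

# With check_errors=True (the default) the distance module can report when
# max_num_neighbors is too small to hold every neighbor; with check_errors=False
# that check is skipped, which CUDA-graph capture requires.
strict = TorchMD_ET(max_num_neighbors=32, check_errors=True)
graphable = TorchMD_ET(max_num_neighbors=64, check_errors=False)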
4 changes: 4 additions & 0 deletions torchmdnet/models/torchmd_gn.py
@@ -84,6 +84,8 @@ class TorchMD_GN(nn.Module):
where `box_vectors[0] = a`, `box_vectors[1] = b`, and `box_vectors[2] = c`.
If this is omitted, periodic boundary conditions are not applied.
(default: :obj:`None`)
check_errors (bool, optional): Whether to check for errors in the distance module.
(default: :obj:`True`)
"""

@@ -101,6 +103,7 @@ def __init__(
cutoff_upper=5.0,
max_z=100,
max_num_neighbors=32,
check_errors=True,
aggr="add",
dtype=torch.float32,
box_vecs=None,
@@ -144,6 +147,7 @@ def __init__(
max_num_pairs=-max_num_neighbors,
box=box_vecs,
long_edge_index=True,
check_errors=check_errors,
)

self.distance_expansion = rbf_class_mapping[rbf_type](
6 changes: 5 additions & 1 deletion torchmdnet/models/torchmd_t.py
@@ -74,6 +74,8 @@ class TorchMD_T(nn.Module):
where `box_vectors[0] = a`, `box_vectors[1] = b`, and `box_vectors[2] = c`.
If this is omitted, periodic boundary conditions are not applied.
(default: :obj:`None`)
check_errors (bool, optional): Whether to check for errors in the distance module.
(default: :obj:`True`)
"""

@@ -91,6 +93,7 @@ def __init__(
distance_influence="both",
cutoff_lower=0.0,
cutoff_upper=5.0,
check_errors=True,
max_z=100,
max_num_neighbors=32,
dtype=torch.float,
@@ -133,7 +136,8 @@ def __init__(
max_num_pairs=-max_num_neighbors,
loop=True,
box=box_vecs,
long_edge_index=True
long_edge_index=True,
check_errors=check_errors,
)

self.distance_expansion = rbf_class_mapping[rbf_type](
3 changes: 3 additions & 0 deletions torchmdnet/scripts/train.py
@@ -98,7 +98,10 @@ def get_argparse():
`a[1] = a[2] = b[2] = 0`;`a[0] >= 2*cutoff, b[1] >= 2*cutoff, c[2] >= 2*cutoff`;`a[0] >= 2*b[0]`;`a[0] >= 2*c[0]`;`b[1] >= 2*c[1]`;
These requirements correspond to a particular rotation of the system and reduced form of the vectors, as well as the requirement that the cutoff be no larger than half the box width.
Example: [[1,0,0],[0,1,0],[0,0,1]]""")
parser.add_argument('--static_shapes', type=bool, default=False, help='If true, TensorNet will use statically shaped tensors for the network, making it capturable into a CUDA graph. In some situations static shapes can lead to a speedup, but they increase memory usage.')

# other args
parser.add_argument('--check_errors', type=bool, default=True, help='Check whether max_num_neighbors is large enough to contain all the neighbors. This check is incompatible with CUDA graphs.')
parser.add_argument('--derivative', default=False, type=bool, help='If true, take the derivative of the prediction w.r.t coordinates')
parser.add_argument('--cutoff-lower', type=float, default=0.0, help='Lower cutoff in model')
parser.add_argument('--cutoff-upper', type=float, default=5.0, help='Upper cutoff in model')
