diff --git a/examples/TensorNet-QM9.yaml b/examples/TensorNet-QM9.yaml
index 6ab98a2c..4000b53f 100644
--- a/examples/TensorNet-QM9.yaml
+++ b/examples/TensorNet-QM9.yaml
@@ -57,3 +57,5 @@ weight_decay: 0.0
 box_vecs: null
 charge: false
 spin: false
+static_shapes: True
+check_errors: False
diff --git a/examples/TensorNet-rMD17.yaml b/examples/TensorNet-rMD17.yaml
index 737e4c95..bc14f73e 100644
--- a/examples/TensorNet-rMD17.yaml
+++ b/examples/TensorNet-rMD17.yaml
@@ -57,3 +57,5 @@ weight_decay: 0.0
 box_vecs: null
 charge: false
 spin: false
+static_shapes: True
+check_errors: False
diff --git a/torchmdnet/models/output_modules.py b/torchmdnet/models/output_modules.py
index bf408aa3..8d913ec9 100644
--- a/torchmdnet/models/output_modules.py
+++ b/torchmdnet/models/output_modules.py
@@ -52,6 +52,7 @@ def reduce(self, x, batch):
                 self.dim_size
             )
         )
+        # self.dim_size = 1
         return scatter(x, batch, dim=0, dim_size=self.dim_size, reduce=self.reduce_op)

     def post_reduce(self, x):
diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py
index a2006c9e..deae9161 100644
--- a/torchmdnet/models/tensornet.py
+++ b/torchmdnet/models/tensornet.py
@@ -10,55 +10,66 @@
     OptimizedDistance,
     rbf_class_mapping,
     act_class_mapping,
+    MLP,
+    nvtx_annotate,
+    nvtx_range,
 )

 __all__ = ["TensorNet"]

-torch.set_float32_matmul_precision("high")
+torch.set_float32_matmul_precision("medium")
 torch.backends.cuda.matmul.allow_tf32 = True


+@nvtx_annotate("vector_to_skewtensor")
 def vector_to_skewtensor(vector):
     """Creates a skew-symmetric tensor from a vector."""
-    batch_size = vector.size(0)
+    batch_size = vector.shape[:-1]
     zero = torch.zeros(batch_size, device=vector.device, dtype=vector.dtype)
     tensor = torch.stack(
         (
             zero,
-            -vector[:, 2],
-            vector[:, 1],
-            vector[:, 2],
+            -vector[..., 2],
+            vector[..., 1],
+            vector[..., 2],
             zero,
-            -vector[:, 0],
-            -vector[:, 1],
-            vector[:, 0],
+            -vector[..., 0],
+            -vector[..., 1],
+            vector[..., 0],
             zero,
         ),
-        dim=1,
+        dim=-1,
     )
-    tensor = tensor.view(-1, 3, 3)
+    tensor = tensor.view(*batch_size, 3, 3)
     return tensor.squeeze(0)


+@nvtx_annotate("vector_to_symtensor")
 def vector_to_symtensor(vector):
     """Creates a symmetric traceless tensor from the outer product of a vector with itself."""
     tensor = torch.matmul(vector.unsqueeze(-1), vector.unsqueeze(-2))
-    I = (tensor.diagonal(offset=0, dim1=-1, dim2=-2)).mean(-1)[
-        ..., None, None
-    ] * torch.eye(3, 3, device=tensor.device, dtype=tensor.dtype)
-    S = 0.5 * (tensor + tensor.transpose(-2, -1)) - I
+    S = tensor_to_symtensor(tensor)
     return S


+@nvtx_annotate("tensor_to_symtensor")
+def tensor_to_symtensor(tensor):
+    S = 0.5 * (tensor + tensor.transpose(-2, -1))
+    I = (tensor.diagonal(dim1=-1, dim2=-2)).mean(-1)
+    S.diagonal(dim1=-1, dim2=-2).sub_(I.unsqueeze(-1))
+    return S
+
+
+@nvtx_annotate("decompose_tensor")
 def decompose_tensor(tensor):
     """Full tensor decomposition into irreducible components."""
-    I = (tensor.diagonal(offset=0, dim1=-1, dim2=-2)).mean(-1)[
-        ..., None, None
-    ] * torch.eye(3, 3, device=tensor.device, dtype=tensor.dtype)
     A = 0.5 * (tensor - tensor.transpose(-2, -1))
-    S = 0.5 * (tensor + tensor.transpose(-2, -1)) - I
+    S = tensor - A
+    I = (tensor.diagonal(dim1=-1, dim2=-2)).mean(-1)
+    S.diagonal(dim1=-1, dim2=-2).sub_(I.unsqueeze(-1))
     return I, A, S


+@nvtx_annotate("tensor_norm")
 def tensor_norm(tensor):
     """Computes Frobenius norm."""
     return (tensor**2).sum((-2, -1))
@@ -173,7 +184,6 @@ def __init__(
             act_class,
             cutoff_lower,
             cutoff_upper,
-            trainable_rbf,
             max_z,
             dtype,
         )
@@ -220,6 +230,47 @@ def reset_parameters(self):
         self.linear.reset_parameters()
         self.out_norm.reset_parameters()

+    @nvtx_annotate("make_static")
+    def _make_static(
+        self, num_nodes: int, edge_index: Tensor, edge_weight: Tensor, edge_vec: Tensor
+    ) -> Tuple[Tensor, Tensor, Tensor]:
+        # Distance module returns -1 for non-existing edges, to avoid having to resize the tensors when we want to ensure static shapes (for CUDA graphs) we make all non-existing edges pertain to a ghost atom
+        if self.static_shapes:
+            mask = (edge_index[0] < 0).unsqueeze(0).expand_as(edge_index)
+            # I trick the model into thinking that the masked edges pertain to the extra atom
+            # WARNING: This can hurt performance if max_num_pairs >> actual_num_pairs
+            edge_index = edge_index.masked_fill(mask, num_nodes)
+            edge_weight = edge_weight.masked_fill(mask[0], 0)
+            edge_vec = edge_vec.masked_fill(
+                mask[0].unsqueeze(-1).expand_as(edge_vec), 0
+            )
+        return edge_index, edge_weight, edge_vec
+
+    @nvtx_annotate("compute_neighbors")
+    def _compute_neighbors(
+        self, pos: Tensor, batch: Tensor, box: Optional[Tensor]
+    ) -> Tuple[Tensor, Tensor, Tensor]:
+        edge_index, edge_weight, edge_vec = self.distance(pos, batch, box)
+        # This assert convinces TorchScript that edge_vec is a Tensor and not an Optional[Tensor]
+        assert (
+            edge_vec is not None
+        ), "Distance module did not return directional information"
+        edge_index, edge_weight, edge_vec = self._make_static(
+            pos.shape[0], edge_index, edge_weight, edge_vec
+        )
+        return edge_index, edge_weight, edge_vec
+
+    @nvtx_annotate("output")
+    def output(self, X: Tensor) -> Tensor:
+        I, A, S = decompose_tensor(X)  # shape: (n_atoms, hidden_channels, 3, 3)
+        x = torch.cat(
+            (3 * I**2, tensor_norm(A), tensor_norm(S)), dim=-1
+        )  # shape: (n_atoms, 3*hidden_channels)
+        x = self.out_norm(x)  # shape: (n_atoms, 3*hidden_channels)
+        x = self.act(self.linear((x)))  # shape: (n_atoms, hidden_channels)
+        return x
+
+    @nvtx_annotate("TensorNet")
     def forward(
         self,
         z: Tensor,
@@ -229,48 +280,66 @@ def forward(
         q: Optional[Tensor] = None,
         s: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:
-        # Obtain graph, with distances and relative position vectors
-        edge_index, edge_weight, edge_vec = self.distance(pos, batch, box)
-        # This assert convinces TorchScript that edge_vec is a Tensor and not an Optional[Tensor]
-        assert (
-            edge_vec is not None
-        ), "Distance module did not return directional information"
-        # Distance module returns -1 for non-existing edges, to avoid having to resize the tensors when we want to ensure static shapes (for CUDA graphs) we make all non-existing edges pertain to a ghost atom
+        if self.static_shapes:
+            z = torch.cat((z, torch.zeros(1, device=z.device, dtype=z.dtype)), dim=0)
         # Total charge q is a molecule-wise property. We transform it into an atom-wise property, with all atoms belonging to the same molecule being assigned the same charge q
         if q is None:
             q = torch.zeros_like(z, device=z.device, dtype=z.dtype)
         else:
             q = q[batch]
-        zp = z
-        if self.static_shapes:
-            mask = (edge_index[0] < 0).unsqueeze(0).expand_as(edge_index)
-            zp = torch.cat((z, torch.zeros(1, device=z.device, dtype=z.dtype)), dim=0)
-            q = torch.cat((q, torch.zeros(1, device=q.device, dtype=q.dtype)), dim=0)
-            # I trick the model into thinking that the masked edges pertain to the extra atom
-            # WARNING: This can hurt performance if max_num_pairs >> actual_num_pairs
-            edge_index = edge_index.masked_fill(mask, z.shape[0])
-            edge_weight = edge_weight.masked_fill(mask[0], 0)
-            edge_vec = edge_vec.masked_fill(
-                mask[0].unsqueeze(-1).expand_as(edge_vec), 0
-            )
-        edge_attr = self.distance_expansion(edge_weight)
-        mask = edge_index[0] == edge_index[1]
-        # Normalizing edge vectors by their length can result in NaNs, breaking Autograd.
-        # I avoid dividing by zero by setting the weight of self edges and self loops to 1
-        edge_vec = edge_vec / edge_weight.masked_fill(mask, 1).unsqueeze(1)
-        X = self.tensor_embedding(zp, edge_index, edge_weight, edge_vec, edge_attr)
+        edge_index, edge_weight, edge_vec = self._compute_neighbors(pos, batch, box)
+        edge_attr = self.distance_expansion(edge_weight)  # shape: (n_edges, num_rbf)
+        X = self.tensor_embedding(
+            z, edge_index, edge_weight, edge_vec, edge_attr
+        )  # shape: (n_atoms, hidden_channels, 3, 3)
         for layer in self.layers:
-            X = layer(X, edge_index, edge_weight, edge_attr, q)
-        I, A, S = decompose_tensor(X)
-        x = torch.cat((tensor_norm(I), tensor_norm(A), tensor_norm(S)), dim=-1)
-        x = self.out_norm(x)
-        x = self.act(self.linear((x)))
-        # # Remove the extra atom
+            X = layer(
+                X, edge_index, edge_weight, edge_attr, q
+            )  # shape: (n_atoms, hidden_channels, 3, 3)
+        x = self.output(X)  # shape: (n_atoms, hidden_channels)
+        # Remove the extra atom
         if self.static_shapes:
             x = x[:-1]
+            z = z[:-1]
         return x, None, z, pos, batch


+class TensorLinear(nn.Module):
+
+    def __init__(self, in_channels, out_channels, dtype=torch.float32):
+        super(TensorLinear, self).__init__()
+        self.linearI = nn.Linear(in_channels, out_channels, bias=False, dtype=dtype)
+        self.linearA = nn.Linear(in_channels, out_channels, bias=False, dtype=dtype)
+        self.linearS = nn.Linear(in_channels, out_channels, bias=False, dtype=dtype)
+
+    def reset_parameters(self):
+        self.linearI.reset_parameters()
+        self.linearA.reset_parameters()
+        self.linearS.reset_parameters()
+
+    @nvtx_annotate("TensorLinear")
+    def forward(self, X: Tensor, factor: Optional[Tensor] = None) -> Tensor:
+        if factor is None:
+            factor = (
+                torch.ones(1, device=X.device, dtype=X.dtype)
+                .unsqueeze(-1)
+                .unsqueeze(-1)
+            ).expand(-1, -1, 3)
+        I, A, S = decompose_tensor(X)
+        I = self.linearI(I) * factor[..., 0]
+        A = (
+            self.linearA(A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+            * factor[..., 1, None, None]
+        )
+        S = (
+            self.linearS(S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+            * factor[..., 2, None, None]
+        )
+        dX = A + S
+        dX.diagonal(dim1=-2, dim2=-1).add_(I.unsqueeze(-1))
+        return dX
+
+
 class TensorEmbedding(nn.Module):
     """Tensor embedding layer.
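The hunks above replace the eye-based `decompose_tensor` with a version that keeps the isotropic component `I` as a per-channel scalar (the mean of the diagonal) and folds it back with `diagonal().add_()`; `TensorLinear` relies on the same recomposition. Below is a minimal standalone sanity check, not part of the patch (it uses a hypothetical random input), showing that the new form is equivalent to the old one and why `TensorNet.output` can use `3 * I**2` where the old code used `tensor_norm(I)`:

# Standalone sanity check (not part of the patch above).
import torch


def decompose_tensor(tensor):
    # Mirrors the patched helper: A is the antisymmetric part, S the symmetric
    # traceless part, and I the mean of the diagonal (one scalar per channel).
    A = 0.5 * (tensor - tensor.transpose(-2, -1))
    S = tensor - A
    I = tensor.diagonal(dim1=-1, dim2=-2).mean(-1)
    S.diagonal(dim1=-1, dim2=-2).sub_(I.unsqueeze(-1))
    return I, A, S


X = torch.randn(5, 8, 3, 3)  # (n_atoms, hidden_channels, 3, 3)
I, A, S = decompose_tensor(X)

# The previous formulation stored the isotropic part as a full matrix I * eye(3).
I_old = I[..., None, None] * torch.eye(3, dtype=X.dtype)
assert torch.allclose(I_old + A + S, X, atol=1e-6)

# Recomposition as done by TensorLinear: dX = A + S, then add I on the diagonal.
X_rec = A + S
X_rec.diagonal(dim1=-2, dim2=-1).add_(I.unsqueeze(-1))
assert torch.allclose(X_rec, X, atol=1e-6)

# The Frobenius norm of I * eye(3) is 3 * I**2, hence the 3 * I**2 term in
# TensorNet.output in place of tensor_norm(I).
assert torch.allclose((I_old**2).sum((-2, -1)), 3 * I**2, atol=1e-6)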
@@ -284,7 +353,6 @@ def __init__(
         activation,
         cutoff_lower,
         cutoff_upper,
-        trainable_rbf=False,
         max_z=128,
         dtype=torch.float32,
     ):
@@ -299,11 +367,7 @@ def __init__(
         self.emb = nn.Embedding(max_z, hidden_channels, dtype=dtype)
         self.emb2 = nn.Linear(2 * hidden_channels, hidden_channels, dtype=dtype)
         self.act = activation()
-        self.linears_tensor = nn.ModuleList()
-        for _ in range(3):
-            self.linears_tensor.append(
-                nn.Linear(hidden_channels, hidden_channels, bias=False)
-            )
+        self.linear_tensor = TensorLinear(hidden_channels, hidden_channels)
         self.linears_scalar = nn.ModuleList()
         self.linears_scalar.append(
             nn.Linear(hidden_channels, 2 * hidden_channels, bias=True, dtype=dtype)
         )
@@ -312,6 +376,8 @@
             nn.Linear(2 * hidden_channels, 3 * hidden_channels, bias=True, dtype=dtype)
         )
         self.init_norm = nn.LayerNorm(hidden_channels, dtype=dtype)
+        self.num_rbf = num_rbf
+        self.hidden_channels = hidden_channels
         self.reset_parameters()

     def reset_parameters(self):
@@ -320,88 +386,133 @@ def reset_parameters(self):
         self.distance_proj3.reset_parameters()
         self.emb.reset_parameters()
         self.emb2.reset_parameters()
-        for linear in self.linears_tensor:
-            linear.reset_parameters()
+        self.linear_tensor.reset_parameters()
         for linear in self.linears_scalar:
             linear.reset_parameters()
         self.init_norm.reset_parameters()

-    def _get_atomic_number_message(self, z: Tensor, edge_index: Tensor) -> Tensor:
+    @nvtx_annotate("normalize_edges")
+    def _normalize_edges(
+        self, edge_index: Tensor, edge_weight: Tensor, edge_vec: Tensor
+    ) -> Tensor:
+        mask = edge_index[0] == edge_index[1]
+        # Normalizing edge vectors by their length can result in NaNs, breaking Autograd.
+        # I avoid dividing by zero by setting the weight of self edges and self loops to 1
+        edge_vec = edge_vec / edge_weight.masked_fill(mask, 1).unsqueeze(1)
+        return edge_vec
+
+    @nvtx_annotate("compute_edge_atomic_features")
+    def _compute_edge_atomic_features(self, z: Tensor, edge_index: Tensor) -> Tensor:
         Z = self.emb(z)
         Zij = self.emb2(
             Z.index_select(0, edge_index.t().reshape(-1)).view(
                 -1, self.hidden_channels * 2
             )
-        )[..., None, None]
+        )
         return Zij

-    def _get_tensor_messages(
-        self, Zij: Tensor, edge_weight: Tensor, edge_vec_norm: Tensor, edge_attr: Tensor
-    ) -> Tuple[Tensor, Tensor, Tensor]:
-        C = self.cutoff(edge_weight).reshape(-1, 1, 1, 1) * Zij
-        eye = torch.eye(3, 3, device=edge_vec_norm.device, dtype=edge_vec_norm.dtype)[
-            None, None, ...
-        ]
-        Iij = self.distance_proj1(edge_attr)[..., None, None] * C * eye
-        Aij = (
-            self.distance_proj2(edge_attr)[..., None, None]
-            * C
-            * vector_to_skewtensor(edge_vec_norm)[..., None, :, :]
-        )
-        Sij = (
+    @nvtx_annotate("compute_edge_tensor_features")
+    def _compute_node_tensor_features(
+        self,
+        z: Tensor,
+        edge_index,
+        edge_weight: Tensor,
+        edge_vec: Tensor,
+        edge_attr: Tensor,
+    ) -> Tensor:
+        edge_vec_norm = self._normalize_edges(
+            edge_index, edge_weight, edge_vec
+        )  # shape: (n_edges, 3)
+        Zij = self.cutoff(edge_weight)[:, None] * self._compute_edge_atomic_features(
+            z, edge_index
+        )  # shape: (n_edges, hidden_channels)
+        A = (
+            self.distance_proj2(edge_attr)[
+                ..., None
+            ]  # shape: (n_edges, hidden_channels, 1)
+            * Zij[..., None]  # shape: (n_edges, hidden_channels, 1)
+            * edge_vec_norm[:, None, :]  # shape: (n_edges, 1, 3)
+        )  # shape: (n_edges, hidden_channels, 3)
+        A = self._aggregate_edge_features(
+            z.shape[0], A, edge_index[0]
+        )  # shape: (n_atoms, hidden_channels, 3)
+        A = vector_to_skewtensor(A)  # shape: (n_atoms, hidden_channels, 3, 3)
+        I = self.distance_proj1(edge_attr) * Zij
+        I = self._aggregate_edge_features(z.shape[0], I, edge_index[0])
+        # Outer product of edge vectors
+        tensor = torch.matmul(
+            edge_vec_norm.unsqueeze(-1), edge_vec_norm.unsqueeze(-2)
+        )  # shape: (n_edges, 3, 3)
+        tensor = (
             self.distance_proj3(edge_attr)[..., None, None]
-            * C
-            * vector_to_symtensor(edge_vec_norm)[..., None, :, :]
-        )
-        return Iij, Aij, Sij
+            * Zij[..., None, None]
+            * tensor[..., None, :, :]
+        )  # shape: (n_edges, hidden_channels, 3, 3)
+        tensor = self._aggregate_edge_features(
+            z.shape[0], tensor, edge_index[0]
+        )  # shape: (n_atoms, hidden_channels, 3, 3)
+        S = tensor_to_symtensor(tensor)  # shape: (n_atoms, hidden_channels, 3, 3)
+        features = A + S
+        features.diagonal(dim1=-2, dim2=-1).add_(I.unsqueeze(-1))
+        return features
+
+    @nvtx_annotate("aggregate_edge_features")
+    def _aggregate_edge_features(
+        self, num_atoms: int, T: Tensor, source_indices: Tensor
+    ) -> Tensor:
+        targetI = torch.zeros(num_atoms, *T.shape[1:], device=T.device, dtype=T.dtype)
+        I = targetI.index_add(dim=0, index=source_indices, source=T)
+        return I
+
+    @nvtx_annotate("norm_mlp")
+    def _norm_mlp(self, norm):
+        norm = self.init_norm(norm)
+        for linear_scalar in self.linears_scalar:
+            norm = self.act(linear_scalar(norm))
+        norm = norm.reshape(-1, self.hidden_channels, 3)
+        return norm

+    @nvtx_annotate("TensorEmbedding")
     def forward(
         self,
         z: Tensor,
         edge_index: Tensor,
         edge_weight: Tensor,
-        edge_vec_norm: Tensor,
+        edge_vec: Tensor,
         edge_attr: Tensor,
     ) -> Tensor:
-        Zij = self._get_atomic_number_message(z, edge_index)
-        Iij, Aij, Sij = self._get_tensor_messages(
-            Zij, edge_weight, edge_vec_norm, edge_attr
-        )
-        source = torch.zeros(
-            z.shape[0], self.hidden_channels, 3, 3, device=z.device, dtype=Iij.dtype
-        )
-        I = source.index_add(dim=0, index=edge_index[0], source=Iij)
-        A = source.index_add(dim=0, index=edge_index[0], source=Aij)
-        S = source.index_add(dim=0, index=edge_index[0], source=Sij)
-        norm = self.init_norm(tensor_norm(I + A + S))
-        for linear_scalar in self.linears_scalar:
-            norm = self.act(linear_scalar(norm))
-        norm = norm.reshape(-1, self.hidden_channels, 3)
-        I = (
-            self.linears_tensor[0](I.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
-            * norm[..., 0, None, None]
-        )
-        A = (
-            self.linears_tensor[1](A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
-            * norm[..., 1, None, None]
-        )
-        S = (
-            self.linears_tensor[2](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
-            * norm[..., 2, None, None]
-        )
-        X = I + A + S
+        X = self._compute_node_tensor_features(
+            z, edge_index, edge_weight, edge_vec, edge_attr
+        )  # shape: (n_atoms, hidden_channels, 3, 3)
+        norm = self._norm_mlp(tensor_norm(X))  # shape: (n_atoms, hidden_channels)
+        X = self.linear_tensor(X, norm)  # shape: (n_atoms, hidden_channels, 3, 3)
         return X


-def tensor_message_passing(
-    edge_index: Tensor, factor: Tensor, tensor: Tensor, natoms: int
-) -> Tensor:
-    """Message passing for tensors."""
-    msg = factor * tensor.index_select(0, edge_index[1])
-    shape = (natoms, tensor.shape[1], tensor.shape[2], tensor.shape[3])
-    tensor_m = torch.zeros(*shape, device=tensor.device, dtype=tensor.dtype)
+@nvtx_annotate("compute_tensor_edge_features")
+def compute_tensor_edge_features(X, edge_index, factor):
+    I, A, S = decompose_tensor(X)
+    msg = factor[..., 1, None, None] * A.index_select(0, edge_index[1]) + factor[
+        ..., 2, None, None
+    ] * S.index_select(0, edge_index[1])
+    msg.diagonal(dim1=-2, dim2=-1).add_(
+        factor[..., 0, None] * I.index_select(0, edge_index[1]).unsqueeze(-1)
+    )
+    return msg
+
+
+@nvtx_annotate("tensor_message_passing")
+def tensor_message_passing(n_atoms: int, edge_index: Tensor, tensor: Tensor) -> Tensor:
+    msg = tensor.index_select(
+        0, edge_index[1]
+    )  # shape = (n_edges, hidden_channels, 3, 3)
+    tensor_m = torch.zeros(
+        (n_atoms, tensor.shape[1], tensor.shape[2], tensor.shape[3]),
+        device=tensor.device,
+        dtype=tensor.dtype,
+    )
     tensor_m = tensor_m.index_add(0, edge_index[0], msg)
-    return tensor_m
+    return tensor_m  # shape = (n_atoms, hidden_channels, 3, 3)


 class Interaction(nn.Module):
@@ -435,11 +546,8 @@ def __init__(
         self.linears_scalar.append(
             nn.Linear(2 * hidden_channels, 3 * hidden_channels, bias=True, dtype=dtype)
         )
-        self.linears_tensor = nn.ModuleList()
-        for _ in range(6):
-            self.linears_tensor.append(
-                nn.Linear(hidden_channels, hidden_channels, bias=False)
-            )
+        self.tensor_linear_in = TensorLinear(hidden_channels, hidden_channels)
+        self.tensor_linear_out = TensorLinear(hidden_channels, hidden_channels)
         self.act = activation()
         self.equivariance_invariance_group = equivariance_invariance_group
         self.reset_parameters()
@@ -447,9 +555,33 @@ def __init__(
     def reset_parameters(self):
         for linear in self.linears_scalar:
             linear.reset_parameters()
-        for linear in self.linears_tensor:
-            linear.reset_parameters()
+        self.tensor_linear_in.reset_parameters()
+        self.tensor_linear_out.reset_parameters()
+
+    @nvtx_annotate("update_tensor_node_features")
+    def _update_tensor_node_features(self, X, X_aggregated):
+        X = self.tensor_linear_in(X)
+        B = torch.matmul(X, X_aggregated)
+        if self.equivariance_invariance_group == "O(3)":
+            A = torch.matmul(X_aggregated, X)
+        elif self.equivariance_invariance_group == "SO(3)":
+            A = B
+        else:
+            raise ValueError("Unknown equivariance group")
+        Xnew = A + B
+        return Xnew
+
+    @nvtx_annotate("compute_vector_node_features")
+    def _compute_vector_node_features(self, edge_attr, edge_weight):
+        C = self.cutoff(edge_weight)
+        for linear_scalar in self.linears_scalar:
+            edge_attr = self.act(linear_scalar(edge_attr))
+        edge_attr = (edge_attr * C.view(-1, 1)).reshape(
+            edge_attr.shape[0], self.hidden_channels, 3
+        )
+        return edge_attr

+    @nvtx_annotate("Interaction")
     def forward(
         self,
         X: Tensor,
@@ -458,40 +590,26 @@ def forward(
         edge_attr: Tensor,
         q: Tensor,
     ) -> Tensor:
-        C = self.cutoff(edge_weight)
-        for linear_scalar in self.linears_scalar:
-            edge_attr = self.act(linear_scalar(edge_attr))
-        edge_attr = (edge_attr * C.view(-1, 1)).reshape(
-            edge_attr.shape[0], self.hidden_channels, 3
-        )
-        X = X / (tensor_norm(X) + 1)[..., None, None]
-        I, A, S = decompose_tensor(X)
-        I = self.linears_tensor[0](I.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
-        A = self.linears_tensor[1](A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
-        S = self.linears_tensor[2](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
-        Y = I + A + S
-        Im = tensor_message_passing(
-            edge_index, edge_attr[..., 0, None, None], I, X.shape[0]
-        )
-        Am = tensor_message_passing(
-            edge_index, edge_attr[..., 1, None, None], A, X.shape[0]
-        )
-        Sm = tensor_message_passing(
-            edge_index, edge_attr[..., 2, None, None], S, X.shape[0]
-        )
-        msg = Im + Am + Sm
-        if self.equivariance_invariance_group == "O(3)":
-            A = torch.matmul(msg, Y)
-            B = torch.matmul(Y, msg)
-            I, A, S = decompose_tensor((1 + 0.1 * q[..., None, None, None]) * (A + B))
-        if self.equivariance_invariance_group == "SO(3)":
-            B = torch.matmul(Y, msg)
-            I, A, S = decompose_tensor(2 * B)
-        normp1 = (tensor_norm(I + A + S) + 1)[..., None, None]
-        I, A, S = I / normp1, A / normp1, S / normp1
-        I = self.linears_tensor[3](I.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
-        A = self.linears_tensor[4](A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
-        S = self.linears_tensor[5](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
-        dX = I + A + S
-        X = X + dX + (1 + 0.1 * q[..., None, None, None]) * torch.matrix_power(dX, 2)
+        X = (
+            X / (tensor_norm(X) + 1)[..., None, None]
+        )  # shape (n_atoms, hidden_channels, 3, 3)
+        node_features = self._compute_vector_node_features(
+            edge_attr, edge_weight
+        )  # shape (n_edges, hidden_channels, 3)
+        Y_edges = compute_tensor_edge_features(
+            X, edge_index, node_features
+        )  # shape (n_edges, hidden_channels, 3, 3)
+        Y_aggregated = tensor_message_passing(
+            X.shape[0], edge_index, Y_edges
+        )  # shape (n_atoms, hidden_channels, 3, 3)
+        Xnew = self._update_tensor_node_features(
+            X, Y_aggregated
+        )  # shape (n_atoms, hidden_channels, 3, 3)
+        dX = self.tensor_linear_out(
+            Xnew / (tensor_norm(Xnew) + 1)[..., None, None]
+        )  # shape (n_atoms, hidden_channels, 3, 3)
+        charge_factor = 1 + 0.1 * q[..., None, None, None]
+        X = (
+            X + (dX + torch.matrix_power(dX, 2)) * charge_factor
+        )  # shape (n_atoms, hidden_channels, 3, 3)
         return X
diff --git a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py
index a0d3e403..297f4e1e 100644
--- a/torchmdnet/models/utils.py
+++ b/torchmdnet/models/utils.py
@@ -209,7 +209,7 @@ def __init__(
         self.use_periodic = True
         if self.box is None:
             self.use_periodic = False
-            self.box = torch.empty((0, 0))
+            self.box = torch.empty((0, 0), device="cpu", dtype=torch.float32)
             if self.strategy == "cell":
                 # Default the box to 3 times the cutoff, really inefficient for the cell list
                 lbox = cutoff_upper * 3.0
@@ -255,9 +255,10 @@ def forward(
         use_periodic = self.use_periodic
         if not use_periodic:
             use_periodic = box is not None
+        self.box = self.box.to(pos.device)
         box = self.box if box is None else box
         assert box is not None, "Box must be provided"
-        box = box.to(pos.dtype)
+        # box = box.to(pos.dtype)
         max_pairs: int = self.max_num_pairs
         if self.max_num_pairs < 0:
             max_pairs = -self.max_num_pairs * pos.shape[0]
@@ -618,3 +619,48 @@ def scatter(
 }

 dtype_mapping = {16: torch.float16, 32: torch.float, 64: torch.float64}
+
+
+# Can be globally disabled by setting the global variable ENABLE_NVTX to False
+class nvtx_range:
+    def __init__(self, name, force_enabled=False):
+        self.name = name
+        self.force_enabled = force_enabled
+
+    def __enter__(self):
+        if self.force_enabled or ENABLE_NVTX:
+            torch.cuda.synchronize()
+            torch.cuda.nvtx.range_push(self.name)
+
+    def __exit__(self, type, value, traceback):
+        if self.force_enabled or ENABLE_NVTX:
+            torch.cuda.synchronize()
+            torch.cuda.nvtx.range_pop()
+
+
+ENABLE_NVTX = False
+
+
+def tmdnet_push_range(name: str, force_enabled: bool = False):
+    if force_enabled or ENABLE_NVTX:
+        torch.cuda.synchronize()
+        torch.cuda.nvtx.range_push(name)
+
+
+def tmdnet_pop_range(force_enabled: bool = False):
+    if force_enabled or ENABLE_NVTX:
+        torch.cuda.synchronize()
+        torch.cuda.nvtx.range_pop()
+
+
+def nvtx_annotate(tag: Optional[str] = None):
+    def Inner(foo):
+        def wrapper(*args, **kwargs):
+            if not ENABLE_NVTX:
+                return foo(*args, **kwargs)
+            with nvtx_range(foo.__name__ if tag is None else tag):
+                return foo(*args, **kwargs)
+
+        return wrapper
+
+    return Inner
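The NVTX helpers added above are no-ops unless the module-level `ENABLE_NVTX` flag is switched on or `force_enabled=True` is passed, and every range is bracketed by `torch.cuda.synchronize()`. A minimal usage sketch follows, assuming a CUDA device is available and that the script is profiled with an NVTX-aware tool; the toy function and range names are illustrative only:

# Hypothetical profiling snippet (not part of the patch): exercises the
# nvtx_range context manager and nvtx_annotate decorator defined above.
import torch
import torchmdnet.models.utils as tm_utils
from torchmdnet.models.utils import nvtx_annotate, nvtx_range

# Ranges are only emitted when the module-level flag is on
# (or when force_enabled=True is passed explicitly).
tm_utils.ENABLE_NVTX = True


@nvtx_annotate("toy_matmul")  # wraps every call in a range named "toy_matmul"
def toy_matmul(a, b):
    return a @ b


x = torch.randn(1024, 1024, device="cuda")  # assumes a CUDA device
with nvtx_range("warmup", force_enabled=True):  # emitted even if ENABLE_NVTX is False
    toy_matmul(x, x)
for _ in range(10):
    with nvtx_range("iteration"):
        toy_matmul(x, x)

Because each range push/pop synchronizes the device, this is intended for coarse profiling runs (e.g. under `nsys profile -t cuda,nvtx python script.py`) rather than for production timing.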