Merge pull request #38 from choderalab/update-readme
Update README.
kaminow authored Jan 19, 2024

Commit ee8a09a (parents 3cae786 and 4f193e8), updating README.md (26 additions, 89 deletions).

Modular Training and Evaluation of Neural Networks
Copyright (c) 2022, Benjamin Kaminow

### Minimal usage example
Building models should be done using the `mtenn.config` API.
A small example for a SchNet model is shown below, but more details for SchNet and other models can be found in the respective class definitions.

We will construct a SchNet model with default parameters and a delta G strategy for combining our complex, protein, and ligand representations (the delta strategy predicts the energy of the complex minus the sum of the protein and ligand energies).
We will leave our predictions in the returned implicit kT units (i.e. no Readout block).
```python
from mtenn.config import SchNetModelConfig

# Create the config using all default parameters (which includes the delta G strategy)
model_config = SchNetModelConfig()

# Build the actual pytorch model
model = model_config.build()
```
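Non-default architectures or strategies can be requested by overriding fields on the config. The following is a minimal sketch, assuming the config exposes fields such as `hidden_channels` and `strategy`; check the `SchNetModelConfig` class definition for the exact field names and allowed values.
```python
from mtenn.config import SchNetModelConfig

# Illustrative sketch only: the field names used here (hidden_channels, strategy)
# are assumptions and should be checked against the SchNetModelConfig definition.
custom_config = SchNetModelConfig(
    hidden_channels=64,   # assumed: size of the hidden atom embeddings
    strategy="concat",    # assumed: alternative to the default delta G strategy
)
custom_model = custom_config.build()
```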

The input passed to this model should be a `dict` with the following keys (based on the underlying model); an illustrative `E3NN` example is sketched after this list:
* `SchNet`
  * `z`: Tensor of atomic numbers for each atom, shape of `(n,)`
  * `pos`: Tensor of coordinates for each atom, shape of `(n,3)`
* `E3NN`
  * `x`: Tensor of one-hot encodings of the element of each atom, shape of `(n,one_hot_length)`
  * `pos`: Tensor of coordinates for each atom, shape of `(n,3)`
  * `z`: Tensor of bool labels of whether each atom is a protein atom (`False`) or ligand atom (`True`), shape of `(n,)`
* `GAT`
  * `g`: DGL graph object
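As a concrete illustration of the `E3NN` input format above, the following sketch assembles a random input `dict`. The values are placeholders; only the keys and shapes follow the list above, and `n_atoms`/`one_hot_length` are arbitrary choices.
```python
import torch

n_atoms = 100
one_hot_length = 17  # assumed number of element types; match your featurization

# Random placeholder data with the shapes listed above
e3nn_pose = {
    # (n, one_hot_length) one-hot element encodings
    "x": torch.nn.functional.one_hot(
        torch.randint(low=0, high=one_hot_length, size=(n_atoms,)),
        num_classes=one_hot_length,
    ).float(),
    # (n, 3) atomic coordinates
    "pos": torch.rand((n_atoms, 3)),
    # (n,) bool mask: False = protein atom, True = ligand atom
    "z": torch.cat(
        [torch.zeros(70, dtype=torch.bool), torch.ones(30, dtype=torch.bool)]
    ),
}
```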

The prediction can then be generated simply with:
```python
import torch

# Using random data just for demonstration purposes
pose = {"z": torch.randint(low=1, high=17, size=(100,)), "pos": torch.rand((100, 3))}
pred = model(pose)
```
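Since the output is an ordinary PyTorch tensor, it can be used directly in a training loop. A minimal sketch, assuming `pred` holds a single predicted value (if the model returns additional outputs, take the prediction tensor first) and using a made-up experimental target in the same implicit kT units:
```python
import torch

# Hypothetical experimental target value, in the same implicit kT units
target = torch.tensor([-5.0])

loss = torch.nn.functional.mse_loss(pred.reshape(-1), target)
loss.backward()  # gradients flow back through the underlying SchNet weights
```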

### Installation

`mtenn` is now on `conda-forge`! To install, simply run
```bash
mamba install -c conda-forge mtenn
```

To install from source instead, an installation of pytorch is required first. For conda users, we provide a minimal environment file, ```environment.yml```, that installs pytorch and the conda dependencies. Create a new environment like so:

```bash
conda env create --file environment.yml
conda activate mtenn
```

Training and inference are often faster using the GPU version of pytorch. We provide a minimal environment file, ```environment-gpu.yml```, that installs the GPU version of pytorch and the conda dependencies. Create a new environment like so:

```bash
conda env create --file environment-gpu.yml
conda activate mtenn-gpu
```

If you are not using conda, pytorch **MUST** be installed before installing the package itself. See the pytorch documentation for the best way to install pytorch on your system.

To install mtenn and its dependencies (excluding pytorch), run
```bash
pip install -e .
```

If you are not using conda, some of the `mtenn` dependencies do not come with pre-built wheels for all platforms, so pip may need to build them from source; this requires a C++ compiler and may take a while. Advanced users can instead install the dependencies listed in `requirements.txt` from manually specified wheels, available at `https://data.pyg.org/whl/`.

If compatibility is proving difficult, you may need to purge your pip cache:
```bash
pip cache purge
```
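To sanity-check the install, a bare import is sufficient. The version lookup in the sketch below is defensive, since it only assumes `mtenn` may expose a `__version__` attribute:
```python
# Minimal smoke test: if this import succeeds, mtenn is installed correctly.
import mtenn

print(getattr(mtenn, "__version__", "unknown version"))
```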

