
Commit

Update README.
kaminow committed Jan 18, 2024
1 parent 3cae786 commit 4f193e8
Showing 1 changed file with 26 additions and 89 deletions.
115 changes: 26 additions & 89 deletions README.md
Modular Training and Evaluation of Neural Networks
Copyright (c) 2022, Benjamin Kaminow

### Minimal usage example
A minimal example of how to use this package (with some random data), building the model directly from the classes in `mtenn.conversion_utils`; the recommended `mtenn.config` API is shown further below.

First generate the data
```python
import torch

## Complex data
## z: random integers in [0, 10)
## pos: random position floats
z_comp = torch.randint(10, (10,))
pos_comp = torch.rand((10,3))

## Protein data
z_prot = z_comp[:7]
pos_prot = pos_comp[:7,:]

## Ligand data
z_lig = z_comp[7:]
pos_lig = pos_comp[7:,:]
```

Construct the `mtenn` SchNet models
```python
from mtenn.conversion_utils import SchNet

## Generate an instance of the mtenn SchNet model
m = SchNet()

## Use that model to construct one Model object using the delta strategy and one
## using the concat strategy
delta_model = SchNet.get_model(model=m, strategy='delta')
concat_model = SchNet.get_model(model=m, strategy='concat')
```

Rearrange the data to pass to `mtenn`
```python
## Our SchNet models take a tuple of (atomic_numbers, positions)

## Complex representation
rep_comp = (z_comp, pos_comp)

## Protein representation
rep_prot = (z_prot, pos_prot)

## Ligand representation
rep_lig = (z_lig, pos_lig)
```

Calculate energies using the different models
```python
## First predict energies using the vanilla SchNet model
e_comp = m(rep_comp)
e_prot = m(rep_prot)
e_lig = m(rep_lig)
## Calculate the delta energy directly
delta_e_og = e_comp - (e_prot + e_lig)

## Use the mtenn Model object to predict the same delta energy
delta_e_new = delta_model(rep_comp, rep_prot, rep_lig)
## Won't be exactly equal because of floating point inaccuracy
assert torch.isclose(delta_e_og, delta_e_new)

## Use the concat Model to predict the delta energy (this will be different from
## the other predicted energies)
concat_e = concat_model(rep_comp, rep_prot, rep_lig)

print(f'Using vanilla SchNet model: {delta_e_og.item():0.5f}')
print(f'Using delta Model: {delta_e_new.item():0.5f}')
print(f'Using concat Model: {concat_e.item():0.5f}')
```

In general, building models should be done using the `mtenn.config` API.
A small example for a SchNet model is shown below, but more details for SchNet and other models can be found in the respective class definitions.

We will construct a SchNet model with default parameters and a delta G strategy for combining our complex, protein, and ligand representations.
We will leave our predictions in the returned implicit kT units (i.e. no Readout block).
```python
from mtenn.config import SchNetModelConfig

# Create the config using all default parameters (which includes the delta G strategy)
model_config = SchNetModelConfig()

# Build the actual pytorch model
model = model_config.build()
```

The input passed to this model should be a `dict` with the following keys (based on the underlying model):
* `SchNet`
    * `z`: Tensor of atomic numbers for each atom, shape of `(n,)`
    * `pos`: Tensor of coordinates for each atom, shape of `(n,3)`
* `E3NN`
    * `x`: Tensor of one-hot encodings of element for each atom, shape of `(n,one_hot_length)`
    * `pos`: Tensor of coordinates for each atom, shape of `(n,3)`
    * `z`: Tensor of bool labels of whether each atom is a protein atom (`False`) or ligand atom (`True`), shape of `(n,)`
* `GAT`
    * `g`: DGL graph object

The prediction can then be generated simply with:
```python
import torch

# Using random data just for demonstration purposes
pose = {"z": torch.randint(low=1, high=17, size=(100,)), "pos": torch.rand((100, 3))}
pred = model(pose)
```
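The other model types follow the same pattern with the keys listed above. As an illustrative sketch only, the snippet below assembles an `E3NN`-style input `dict`; the atom count and `one_hot_length` are made-up values, and it assumes an E3NN model has been built analogously to the SchNet example above (via the corresponding config class in `mtenn.config`, if available).
```python
import torch

# Illustrative sizes only -- not taken from the mtenn docs
n_atoms = 100
one_hot_length = 10  # hypothetical length of the element one-hot encoding

# E3NN-style input dict, following the keys listed above
e3nn_pose = {
    # one-hot element encodings, shape (n, one_hot_length)
    "x": torch.nn.functional.one_hot(
        torch.randint(one_hot_length, (n_atoms,)), one_hot_length
    ).float(),
    # atomic coordinates, shape (n, 3)
    "pos": torch.rand((n_atoms, 3)),
    # False = protein atom, True = ligand atom (here the last 10 atoms are the ligand)
    "z": torch.tensor([False] * (n_atoms - 10) + [True] * 10),
}
# A model built from an E3NN config would then be called the same way as above:
# pred = e3nn_model(e3nn_pose)
```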

### Installation

Installation of PyTorch is required. For conda users, we provide a minimal environment file, `environment.yml`, that installs PyTorch and the conda dependencies. You should create a new environment like so:

```bash
conda env create --file environment.yml
conda activate mtenn
```

Training and inference are often faster using the GPU version of PyTorch.
For conda users, we provide a minimal environment file, `environment-gpu.yml`, that installs the GPU version of PyTorch and the conda dependencies. You should create a new environment like so:

```bash
conda env create --file environment-gpu.yml
conda activate mtenn-gpu
```

If not using conda, PyTorch **MUST** be installed before installing the package itself. See the PyTorch documentation for how best to install PyTorch on your system.
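For example, a CPU-only pip installation of PyTorch might look like the command below; this is only a sketch, and the exact command for your platform and CUDA setup should be taken from the PyTorch website.
```bash
# Example only -- use the command from https://pytorch.org that matches your system
pip install torch --index-url https://download.pytorch.org/whl/cpu
```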

To install `mtenn` and its dependencies (excluding PyTorch), run
```bash
pip install -e .
```

If not using conda, note that some of the `mtenn` dependencies do not come with pre-built wheels for all platforms, so pip may need to build them from source; this requires a C++ compiler and may take a while.
Advanced users can instead install the dependency packages listed in `requirements.txt` directly, using manually specified wheels from `https://data.pyg.org/whl/`, as shown in the sketch below.
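As a sketch only, such a manual install might look like the command below; the package names and the `torch-2.1.0+cpu` tag are illustrative and should be matched to what `requirements.txt` actually lists and to your installed PyTorch/CUDA versions.
```bash
# Example only -- pick packages from requirements.txt and the wheel index matching your setup
pip install torch_scatter torch_sparse torch_cluster -f https://data.pyg.org/whl/torch-2.1.0+cpu.html
```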

If compatibility is proving difficult, you may need to purge your pip cache:
```bash
pip cache purge
```

`mtenn` is now on `conda-forge`! To install it from there, simply run
```bash
mamba install -c conda-forge mtenn
```


