Skip to content

Commit

Permalink
Add requirements and test to readme
Browse files Browse the repository at this point in the history
  • Loading branch information
smsharma committed May 30, 2024
1 parent 57ec01c commit 66592b6
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 33 deletions.
13 changes: 12 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,18 @@ Additionally, the following non-equivariant models are implemented:
- [DiffPool](./models/diffpool.py) ([Ying et al 2018](https://arxiv.org/abs/1806.08804))
- [Set Transformer](./models/transformer.py) ([Lee et al 2019](https://arxiv.org/abs/1810.00825))

## Requirements
## Requirements and tests

To install requirements:
```
pip install -r requirements.txt
```

To run tests (testing equivariance and periodic boundary conditions):
```
cd tests
pytest .
```

## Basic usage and examples

Expand Down
56 changes: 24 additions & 32 deletions notebooks/examples.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
},
{
"cell_type": "code",
"execution_count": 53,
"execution_count": 65,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -48,7 +48,7 @@
},
{
"cell_type": "code",
"execution_count": 54,
"execution_count": 66,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -57,7 +57,16 @@
"position_features = True\n",
"r_max = 0.6\n",
"use_3d_distances = False\n",
"l_max = 1"
"l_max = 1\n",
"\n",
"graph = build_graph(x_points, \n",
" None, \n",
" k=k, \n",
" use_edges=True, \n",
" n_radial_basis=n_radial,\n",
" r_max=r_max,\n",
" use_3d_distances=use_3d_distances,\n",
")"
]
},
{
Expand All @@ -69,7 +78,7 @@
},
{
"cell_type": "code",
"execution_count": 55,
"execution_count": 67,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -97,7 +106,7 @@
},
{
"cell_type": "code",
"execution_count": 56,
"execution_count": 68,
"metadata": {},
"outputs": [
{
Expand All @@ -114,25 +123,14 @@
" [ 0.11992944, -0.09075886]], dtype=float32)"
]
},
"execution_count": 56,
"execution_count": 68,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"graph = build_graph(x_points, \n",
" None, \n",
" k=k, \n",
" use_edges=True, \n",
" n_radial_basis=n_radial,\n",
" r_max=r_max,\n",
" use_3d_distances=use_3d_distances,\n",
")\n",
"\n",
"model = GraphWrapperGNN(GNN_PARAMS)\n",
"out, params = model.init_with_output(jax.random.PRNGKey(0), graph)\n",
"\n",
"# Number of parameters\n",
"print(f\"Number of parameters: {sum([p.size for p in jax.tree.leaves(params)])}\")\n",
"\n",
"out"
Expand All @@ -147,7 +145,7 @@
},
{
"cell_type": "code",
"execution_count": 57,
"execution_count": 69,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -205,7 +203,7 @@
},
{
"cell_type": "code",
"execution_count": 58,
"execution_count": 70,
"metadata": {},
"outputs": [
{
Expand All @@ -222,16 +220,14 @@
" [-0.00483601, -0.00015454]], dtype=float32)"
]
},
"execution_count": 58,
"execution_count": 70,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = GraphWrapperSEGNN(SEGNN_PARAMS, )\n",
"out, params = model.init_with_output(jax.random.PRNGKey(0), graph)\n",
"\n",
"# Number of parameters\n",
"print(f\"Number of parameters: {sum([p.size for p in jax.tree.leaves(params)])}\")\n",
"\n",
"out"
Expand All @@ -246,7 +242,7 @@
},
{
"cell_type": "code",
"execution_count": 61,
"execution_count": 71,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -278,7 +274,7 @@
},
{
"cell_type": "code",
"execution_count": 62,
"execution_count": 72,
"metadata": {},
"outputs": [
{
Expand All @@ -295,16 +291,14 @@
" [ -5.1505995, -9.6540985]], dtype=float32)"
]
},
"execution_count": 62,
"execution_count": 72,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = GraphWrapperNequIP(NEQUIP_PARAMS)\n",
"out, params = model.init_with_output(jax.random.PRNGKey(0), graph)\n",
"\n",
"# Number of parameters\n",
"print(f\"Number of parameters: {sum([p.size for p in jax.tree.leaves(params)])}\")\n",
"\n",
"out"
Expand All @@ -319,7 +313,7 @@
},
{
"cell_type": "code",
"execution_count": 63,
"execution_count": 73,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -332,7 +326,7 @@
},
{
"cell_type": "code",
"execution_count": 64,
"execution_count": 74,
"metadata": {},
"outputs": [
{
Expand All @@ -349,16 +343,14 @@
" [-0.0085367 , -0.09540764]], dtype=float32)"
]
},
"execution_count": 64,
"execution_count": 74,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = GraphWrapperEGNN({})\n",
"out, params = model.init_with_output(jax.random.PRNGKey(0), graph)\n",
"\n",
"# Number of parameters\n",
"print(f\"Number of parameters: {sum([p.size for p in jax.tree.leaves(params)])}\")\n",
"\n",
"out"
Expand Down

0 comments on commit 66592b6

Please sign in to comment.