Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

basic python api #31

Merged
merged 22 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
lcov.info

.DS_Store
snippets*
# error data files
Expand Down
82 changes: 68 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,16 @@ Implementation of [STAC](https://ieeexplore.ieee.org/document/7030016) using [MJ

## Installation
stac-mjx relies on many prerequisites, therefore we suggest installing in a new conda environment, using the provided `environment.yaml`:

Create and activate the `stac-mjx-env` environment:
[Local install before official package publish]
charles-zhng marked this conversation as resolved.
Show resolved Hide resolved
1. Clone the repository `git clone https://github.com/talmolab/stac-mjx.git` and `cd` into it
2. Create and activate the `stac-mjx-env` environment:
charles-zhng marked this conversation as resolved.
Show resolved Hide resolved

```
conda env create -f environment.yaml
conda activate stac-mjx-env
```

## Usage
1. Update the .yaml files in `config/` with the proper information (details WIP).

2. For new data, first run stac on just a small subset of the data with

`python stac_mjx/main.py test.skip_transform=True`

Note: this currently will fail w/o supplying a data file.

3. Render the resulting data using `mujoco_viz()` from within `viz_usage.ipynb`. Currently, this uses headless rendering on CPU via `osmesa`, which requires its own setup. To set up (currently on supported on Linux), execute the following commands sequentially:
Our rendering functions support multiple backends: `egl`, `glfw`, and `osmesa`. We show `osmesa` setup as it supports headless rendering which is common in remote/cluster setups. To set up (currently on supported on Linux), execute the following commands sequentially:
charles-zhng marked this conversation as resolved.
Show resolved Hide resolved
```
sudo apt-get install libglfw3 libglew2.0 libgl1-mesa-glx libosmesa6
conda install -c conda-forge glew
Expand All @@ -39,5 +31,67 @@ conda activate stac-mjx-env
python -m ipykernel install --user --name stac-mjx-env --display-name "Python (stac-mjx-env)"
```

4. After tuning parameters and confirming the small clip is processed well, run through the whole thing with
`python stac-mjx/main.py`

## Usage
1. Update the .yaml files in `config/` with the proper information (details WIP).

2. Run stac-mjx with its basic api: `load_configs` for loading configs and `run_stac` for the keypoint registration. Below is an example script (also found in `docs/use_api.ipynb`).
TODO: Use our dataloaders in this example
charles-zhng marked this conversation as resolved.
Show resolved Hide resolved

```python
from stac_mjx import main
from stac_mjx import utils
from jax import numpy as jp
import numpy as np

stac_config_path = "../configs/stac.yaml"
model_config_path = "../configs/rodent.yaml"

cfg = main.load_configs(stac_config_path, model_config_path)
data_path = "./save_data_avg.mat"
# Set up mocap data
kp_names = utils.params["KP_NAMES"]
# argsort returns the indices that would sort the array
stac_keypoint_order = np.argsort(kp_names)
data_path = cfg.paths.data_path

# Load kp_data, /1000 to scale data (from mm to meters)
kp_data = utils.loadmat(data_path)["pred"][:] / 1000

# Preparing data by reordering and reshaping
# Resulting kp_data is of shape (n_frames, n_keypoints)
kp_data = jp.array(kp_data[:, :, stac_keypoint_order])
kp_data = jp.transpose(kp_data, (0, 2, 1))
kp_data = jp.reshape(kp_data, (kp_data.shape[0], -1))

# Run stac
main.run_stac(cfg, kp_data)
```

3. Render the resulting data using `mujoco_viz()` (example notebook found in `docs/viz_usage.ipynb`):
```python
import os
import mediapy as media

from stac_mjx.viz import mujoco_viz
from stac_mjx import main
from stac_mjx import utils

stac_config_path = "../configs/stac.yaml"
model_config_path = "../configs/rodent.yaml"

cfg = main.load_configs(stac_config_path, model_config_path)

rat_xml = "./models/rodent.xml"
data_path = "./transform.p"
n_frames=250
save_path="./videos/direct_render.mp4"

# Call mujoco_viz
frames = mujoco_viz(data_path, rat_xml, n_frames, save_path, start_frame=0)

# Show the video in the notebook (it is also saved to the save_path)
media.show_video(frames, fps=utils.params["RENDER_FPS"])
```

4. If the rendering is poor, it's likely that some hyperparameter tuning is necessary. (details WIP)
19 changes: 10 additions & 9 deletions configs/rodent.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
# Frames per clip for transform.
N_FRAMES_PER_CLIP: 360
N_FRAMES_PER_CLIP: 250

# Tolerance for the optimizations of the full model, limb, and root.
# TODO: Re-implement optimizer loops to use these tolerances
jf514 marked this conversation as resolved.
Show resolved Hide resolved
# FTOL: 5.0e-03
# ROOT_FTOL: 1.0e-05
# LIMB_FTOL: 1.0e-06

# Number of alternating pose and offset optimization rounds.
N_ITERS: 6

KP_NAMES:
- 'Snout'
Expand Down Expand Up @@ -171,14 +180,6 @@ RENDER_FPS: 50

N_SAMPLE_FRAMES: 100

# Tolerance for the optimizations of the full model, limb, and root.
FTOL: 1.0e-02
ROOT_FTOL: 1.0e-05
LIMB_FTOL: 1.0e-06

# Number of alternating pose and offset optimization rounds.
N_ITERS: 6

# If you have reason to believe that the initial offsets are correct for particular keypoints,
# you can regularize those sites using _SITES_TO_REGULARIZE.
M_REG_COEF: 1
Expand Down
15 changes: 5 additions & 10 deletions configs/stac.yaml
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
paths:
model_config: "rodent"
xml: "././models/rodent.xml"
fit_path: "fit_sq.p"
transform_path: "transform_sq.p"
xml: "./models/rodent.xml"
fit_path: "fit.p"
transform_path: "transform.p"
data_path: "./tests/data/test_pred_only_1000_frames.mat"

n_fit_frames: 1000
sampler: "first" # first, every, or random
first_start: 0 # starting frame for "first" sampler

# Should this be included?
test:
skip_fit: False
skip_transform: False
skip_fit: False
skip_transform: True

mujoco:
solver: "newton"
Expand Down
Empty file added conftest.py
Empty file.
Loading