Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

basic python api #31

Merged
merged 22 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
lcov.info

.DS_Store
snippets*
# error data files
Expand Down
19 changes: 10 additions & 9 deletions configs/rodent.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
# Frames per clip for transform.
N_FRAMES_PER_CLIP: 360
N_FRAMES_PER_CLIP: 250

# Tolerance for the optimizations of the full model, limb, and root.
# TODO: Re-implement optimizer loops to use these tolerances
jf514 marked this conversation as resolved.
Show resolved Hide resolved
# FTOL: 5.0e-03
# ROOT_FTOL: 1.0e-05
# LIMB_FTOL: 1.0e-06

# Number of alternating pose and offset optimization rounds.
N_ITERS: 6

KP_NAMES:
- 'Snout'
Expand Down Expand Up @@ -171,14 +180,6 @@ RENDER_FPS: 50

N_SAMPLE_FRAMES: 100

# Tolerance for the optimizations of the full model, limb, and root.
FTOL: 1.0e-02
ROOT_FTOL: 1.0e-05
LIMB_FTOL: 1.0e-06

# Number of alternating pose and offset optimization rounds.
N_ITERS: 6

# If you have reason to believe that the initial offsets are correct for particular keypoints,
# you can regularize those sites using _SITES_TO_REGULARIZE.
M_REG_COEF: 1
Expand Down
15 changes: 5 additions & 10 deletions configs/stac.yaml
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
paths:
model_config: "rodent"
xml: "././models/rodent.xml"
fit_path: "fit_sq.p"
transform_path: "transform_sq.p"
xml: "./models/rodent.xml"
fit_path: "fit.p"
transform_path: "transform.p"
data_path: "./tests/data/test_pred_only_1000_frames.mat"

n_fit_frames: 1000
sampler: "first" # first, every, or random
first_start: 0 # starting frame for "first" sampler

# Should this be included?
test:
skip_fit: False
skip_transform: False
skip_fit: False
skip_transform: True

mujoco:
solver: "newton"
Expand Down
Empty file added conftest.py
Empty file.
287 changes: 287 additions & 0 deletions docs/use_api.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,287 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"from stac_mjx import main\n",
"from stac_mjx import utils\n",
"from jax import numpy as jp\n",
"import numpy as np\n",
"\n",
"stac_config_path = \"../configs/stac.yaml\"\n",
"model_config_path = \"../configs/rodent.yaml\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load configs"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"cfg = main.load_configs(stac_config_path, model_config_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Prepare your data"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"data_path = \"./save_data_avg.mat\"\n",
"# Set up mocap data\n",
"kp_names = utils.params[\"KP_NAMES\"]\n",
"# argsort returns the indices that would sort the array\n",
"stac_keypoint_order = np.argsort(kp_names)\n",
"data_path = cfg.paths.data_path\n",
"\n",
"# Load kp_data, /1000 to scale data (from mm to meters)\n",
"kp_data = utils.loadmat(data_path)[\"pred\"][:] / 1000\n",
"\n",
"# Preparing data by reordering and reshaping (TODO: will this stay the same?)\n",
"# Resulting kp_data is of shape (n_frames, n_keypoints)\n",
"kp_data = jp.array(kp_data[:, :, stac_keypoint_order])\n",
"kp_data = jp.transpose(kp_data, (0, 2, 1))\n",
"kp_data = jp.reshape(kp_data, (kp_data.shape[0], -1))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Run stac"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Root Optimization:\n",
"q_opt 1 finished in 148.1901571750641 with an error of 0.0008171782246790826\n",
"Replace 1 finished in 20.782493114471436\n",
"starting q_opt 2\n",
"q_opt 1 finished in 0.043703317642211914 with an error of 0.0005677181179635227\n",
"Replace 2 finished in 0.003937721252441406\n",
"Root optimization finished in 169.30976676940918\n",
"Calibration iteration: 1/6\n",
"Pose Optimization:\n",
"Pose Optimization done in 7.687406063079834\n",
"Frame 1 done in 0.993701696395874 with a final error of 0.00024175415455829352\n",
"Frame 2 done in 0.7293612957000732 with a final error of 0.00023023398534860462\n",
"Frame 3 done in 0.7321372032165527 with a final error of 0.0002650381939020008\n",
"Frame 4 done in 0.7312824726104736 with a final error of 0.00014305066724773496\n",
"Frame 5 done in 0.7314968109130859 with a final error of 0.00013442340423353016\n",
"Frame 6 done in 0.729010820388794 with a final error of 0.00014300849579740316\n",
"Frame 7 done in 0.7290098667144775 with a final error of 0.00013562907406594604\n",
"Frame 8 done in 0.7302534580230713 with a final error of 0.00016131771553773433\n",
"Frame 9 done in 0.7276926040649414 with a final error of 0.0001525822008261457\n",
"Frame 10 done in 0.7264423370361328 with a final error of 0.00010605651186779141\n",
"Flattened array shape: (10,)\n",
"Mean: 0.00017130945343524218\n",
"Standard deviation: 5.114811938256025e-05\n",
"starting offset optimization\n",
"Begining offset optimization:\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/charles/miniconda3/envs/stac-mjx/lib/python3.12/site-packages/jaxopt/_src/optax_wrapper.py:120: DeprecationWarning: jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).\n",
" return jax.tree_map(update_fun, params, updates)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Final error of 0.0008072515483945608\n",
"offset optimization finished in 55.5665762424469\n",
"Calibration iteration: 2/6\n",
"Pose Optimization:\n",
"Pose Optimization done in 7.4473512172698975\n",
"Frame 1 done in 0.749556303024292 with a final error of 9.767073061084375e-05\n",
"Frame 2 done in 0.7352626323699951 with a final error of 8.220987365348265e-05\n",
"Frame 3 done in 0.7272822856903076 with a final error of 9.601547208148986e-05\n",
"Frame 4 done in 0.7256743907928467 with a final error of 7.642318087164313e-05\n",
"Frame 5 done in 0.7421412467956543 with a final error of 5.553064329433255e-05\n",
"Frame 6 done in 0.749272346496582 with a final error of 5.4576499678660184e-05\n",
"Frame 7 done in 0.757559061050415 with a final error of 9.16418430279009e-05\n",
"Frame 8 done in 0.7577402591705322 with a final error of 0.00010642935376381502\n",
"Frame 9 done in 0.748542070388794 with a final error of 7.401553739327937e-05\n",
"Frame 10 done in 0.7524290084838867 with a final error of 6.143632344901562e-05\n",
"Flattened array shape: (10,)\n",
"Mean: 7.959494541864842e-05\n",
"Standard deviation: 1.744808650983032e-05\n",
"starting offset optimization\n",
"Begining offset optimization:\n",
"Final error of 0.0006800534902140498\n",
"offset optimization finished in 0.40744948387145996\n",
"Calibration iteration: 3/6\n",
"Pose Optimization:\n",
"Pose Optimization done in 7.439974546432495\n",
"Frame 1 done in 0.7551703453063965 with a final error of 0.00010688685142667964\n",
"Frame 2 done in 0.7564406394958496 with a final error of 3.6775491025764495e-05\n",
"Frame 3 done in 0.7530312538146973 with a final error of 0.00011664371413644403\n",
"Frame 4 done in 0.765972375869751 with a final error of 6.304767157416791e-05\n",
"Frame 5 done in 0.7410047054290771 with a final error of 8.429792069364339e-05\n",
"Frame 6 done in 0.7357532978057861 with a final error of 6.2231243646238e-05\n",
"Frame 7 done in 0.7366819381713867 with a final error of 9.646685066400096e-05\n",
"Frame 8 done in 0.7400796413421631 with a final error of 7.76087908889167e-05\n",
"Frame 9 done in 0.7264938354492188 with a final error of 7.353506953222677e-05\n",
"Frame 10 done in 0.7278203964233398 with a final error of 6.625505920965225e-05\n",
"Flattened array shape: (10,)\n",
"Mean: 7.837486191419885e-05\n",
"Standard deviation: 2.244278948637657e-05\n",
"starting offset optimization\n",
"Begining offset optimization:\n",
"Final error of 0.0007646752637811005\n",
"offset optimization finished in 0.39266467094421387\n",
"Calibration iteration: 4/6\n",
"Pose Optimization:\n",
"Pose Optimization done in 7.296305894851685\n",
"Frame 1 done in 0.7274603843688965 with a final error of 0.00010297569679096341\n",
"Frame 2 done in 0.7252860069274902 with a final error of 8.764992526266724e-05\n",
"Frame 3 done in 0.7346749305725098 with a final error of 0.0001309439103351906\n",
"Frame 4 done in 0.734161376953125 with a final error of 6.789484177716076e-05\n",
"Frame 5 done in 0.7315793037414551 with a final error of 6.769385072402656e-05\n",
"Frame 6 done in 0.7343993186950684 with a final error of 6.748946907464415e-05\n",
"Frame 7 done in 0.7363507747650146 with a final error of 9.585278894519433e-05\n",
"Frame 8 done in 0.7283594608306885 with a final error of 6.904124893480912e-05\n",
"Frame 9 done in 0.721167802810669 with a final error of 6.818402471253648e-05\n",
"Frame 10 done in 0.7211976051330566 with a final error of 3.0830520699964836e-05\n",
"Flattened array shape: (10,)\n",
"Mean: 7.885562808951363e-05\n",
"Standard deviation: 2.559636595833581e-05\n",
"starting offset optimization\n",
"Begining offset optimization:\n",
"Final error of 0.0008837199420668185\n",
"offset optimization finished in 0.3865945339202881\n",
"Calibration iteration: 5/6\n",
"Pose Optimization:\n",
"Pose Optimization done in 7.228006601333618\n",
"Frame 1 done in 0.7204337120056152 with a final error of 0.0001364046474918723\n",
"Frame 2 done in 0.7182929515838623 with a final error of 8.813106251182035e-05\n",
"Frame 3 done in 0.7210404872894287 with a final error of 0.0001402536581736058\n",
"Frame 4 done in 0.7190110683441162 with a final error of 6.626713002333418e-05\n",
"Frame 5 done in 0.7206106185913086 with a final error of 6.728066364303231e-05\n",
"Frame 6 done in 0.7251994609832764 with a final error of 7.111558079486713e-05\n",
"Frame 7 done in 0.7227449417114258 with a final error of 9.620625496609136e-05\n",
"Frame 8 done in 0.7219099998474121 with a final error of 6.149008549982682e-05\n",
"Frame 9 done in 0.7217881679534912 with a final error of 6.657913036178797e-05\n",
"Frame 10 done in 0.7354505062103271 with a final error of 3.0597089789807796e-05\n",
"Flattened array shape: (10,)\n",
"Mean: 8.243253250839189e-05\n",
"Standard deviation: 3.2363965146942064e-05\n",
"starting offset optimization\n",
"Begining offset optimization:\n",
"Final error of 0.0009675284381955862\n",
"offset optimization finished in 0.38419532775878906\n",
"Calibration iteration: 6/6\n",
"Pose Optimization:\n",
"Pose Optimization done in 7.329638957977295\n",
"Frame 1 done in 0.7392983436584473 with a final error of 0.00013529157149605453\n",
"Frame 2 done in 0.7381787300109863 with a final error of 8.770832209847867e-05\n",
"Frame 3 done in 0.7381834983825684 with a final error of 0.00014667848881799728\n",
"Frame 4 done in 0.7422003746032715 with a final error of 6.472377572208643e-05\n",
"Frame 5 done in 0.7343780994415283 with a final error of 6.871286313980818e-05\n",
"Frame 6 done in 0.725665807723999 with a final error of 0.00010612722689984366\n",
"Frame 7 done in 0.7254834175109863 with a final error of 7.647361053386703e-05\n",
"Frame 8 done in 0.7276837825775146 with a final error of 5.803934982395731e-05\n",
"Frame 9 done in 0.7250556945800781 with a final error of 6.393478543031961e-05\n",
"Frame 10 done in 0.7320175170898438 with a final error of 3.011138687725179e-05\n",
"Flattened array shape: (10,)\n",
"Mean: 8.378014172194526e-05\n",
"Standard deviation: 3.421222936594859e-05\n",
"starting offset optimization\n",
"Begining offset optimization:\n",
"Final error of 0.0007739891298115253\n",
"offset optimization finished in 0.38109660148620605\n",
"Final pose optimization\n",
"Pose Optimization:\n",
"Pose Optimization done in 7.450073719024658\n",
"Frame 1 done in 0.7415714263916016 with a final error of 0.00013466487871482968\n",
"Frame 2 done in 0.7458879947662354 with a final error of 8.899954264052212e-05\n",
"Frame 3 done in 0.748225212097168 with a final error of 0.00015047925990074873\n",
"Frame 4 done in 0.7425541877746582 with a final error of 6.72857349854894e-05\n",
"Frame 5 done in 0.7349767684936523 with a final error of 6.710427987854928e-05\n",
"Frame 6 done in 0.7367062568664551 with a final error of 0.00010782179742818698\n",
"Frame 7 done in 0.7347347736358643 with a final error of 7.660775736439973e-05\n",
"Frame 8 done in 0.7530131340026855 with a final error of 5.502960993908346e-05\n",
"Frame 9 done in 0.7635860443115234 with a final error of 6.209456478245556e-05\n",
"Frame 10 done in 0.7473659515380859 with a final error of 2.9612450816784985e-05\n",
"Flattened array shape: (10,)\n",
"Mean: 8.396997873205692e-05\n",
"Standard deviation: 3.532067785272375e-05\n",
"shape of qpos: (10, 74)\n"
]
},
{
"data": {
"text/plain": [
"('fit.p', 'No transform path')"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"main.run_stac(cfg, kp_data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "stac-mjx",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
File renamed without changes.
Loading