-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Save output data as h5 #84
Conversation
WalkthroughThis pull request introduces significant changes to the Stac-MJX library, focusing on data handling, configuration management, and utility functions. The modifications span across multiple files, with key updates including the removal of Changes
Possibly related PRs
Suggested reviewers
Poem
✨ Finishing Touches
Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media? 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #84 +/- ##
==========================================
+ Coverage 39.86% 44.26% +4.40%
==========================================
Files 10 9 -1
Lines 582 689 +107
==========================================
+ Hits 232 305 +73
- Misses 350 384 +34 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 3
🔭 Outside diff range comments (3)
stac_mjx/io.py (1)
Line range hint
113-115
: Update error message to include '.h5' as a supported file extensionThe error message currently indicates that only '.nwb' or '.mat' files are supported, but the code also supports '.h5' files. Please update the error message to reflect this.
Apply this diff to fix the error message:
raise ValueError( - "Unsupported file extension. Please provide a .nwb or .mat file." + "Unsupported file extension. Please provide a .nwb, .h5, or .mat file." )stac_mjx/viz.py (1)
Line range hint
55-65
: Breaking change: Function return signature has changed.The function now returns a tuple
(cfg, frames)
instead of justframes
. This change could break existing callers of the function.Consider one of these approaches:
- Document the breaking change and update all callers
- Keep the original return signature and only return the frames
configs/stac/stac_fly_tethered.yaml (1)
Line range hint
1-2
: Standardize output file formats to use .h5.While ik_only_path uses .h5, fit_offsets_path still uses .p (pickle) format. Consider standardizing all outputs to use h5 format as per PR objective.
🧹 Nitpick comments (11)
stac_mjx/viz.py (1)
8-9
: Clean up unused imports.The following imports are not used in the code:
DictConfig
fromomegaconf
Dict
fromtyping
-from omegaconf import DictConfig -from typing import Union, Dict +from typing import Union🧰 Tools
🪛 Ruff (0.8.2)
8-8:
omegaconf.DictConfig
imported but unusedRemove unused import:
omegaconf.DictConfig
(F401)
9-9:
typing.Dict
imported but unusedRemove unused import:
typing.Dict
(F401)
stac_mjx/main.py (2)
6-8
: Remove unused imports.The following imports are not used in the code:
pickle
logging
-import pickle -import logging🧰 Tools
🪛 Ruff (0.8.2)
6-6:
pickle
imported but unusedRemove unused import:
pickle
(F401)
8-8:
logging
imported but unusedRemove unused import:
logging
(F401)
89-89
: Remove useless expression.The tuple expression on line 89 is not being used.
- (fit_offsets_data, fit_offsets_path)
🧰 Tools
🪛 Ruff (0.8.2)
89-89: Found useless expression. Either assign it to a variable or remove it.
(B018)
stac_mjx/compute_stac.py (1)
5-5
: Remove unused import.The
numpy
import is not used in the code.-import numpy as np
🧰 Tools
🪛 Ruff (0.8.2)
5-5:
numpy
imported but unusedRemove unused import:
numpy
(F401)
stac_mjx/utils.py (1)
199-209
: Uncomment and fix the error handling in_clip_within_precision
.The error handling is commented out due to JIT compilation issues. Consider using
jax.debug.print
for debugging or implement a different error handling strategy.Apply this diff to improve error handling:
- # This is raising an error when jitted - # def _raise_if_not_in_precision(): - # if (number < low - precision).any() or (number > high + precision).any(): - # raise ValueError( - # "Input {:.12f} not inside range [{:.12f}, {:.12f}] with precision {}".format( - # number, low, high, precision - # ) - # ) - - # jax.debug.callback(_raise_if_not_in_precision) + # Use jax.debug.print for debugging in JIT-compiled functions + def _check_precision(): + is_below = (number < low - precision).any() + is_above = (number > high + precision).any() + if is_below or is_above: + jax.debug.print( + "Warning: Input {x} not inside range [{l}, {h}] with precision {p}", + x=number, + l=low, + h=high, + p=precision + ) + jax.lax.cond( + (number < low - precision).any() | (number > high + precision).any(), + lambda _: _check_precision(), + lambda _: None, + operand=None + )stac_mjx/stac.py (1)
200-207
: Improve error stats type handling.The function now accepts a list but uses numpy array operations. Consider using JAX's numpy (
jp
) for consistency with the rest of the codebase.Apply this diff to use JAX's numpy:
- flattened_errors = np.array(errors).reshape(-1) - mean = np.mean(flattened_errors) - std = np.std(flattened_errors) + flattened_errors = jp.array(errors).reshape(-1) + mean = jp.mean(flattened_errors) + std = jp.std(flattened_errors)tests/configs/stac/test_stac.yaml (1)
8-9
: Fix trailing space and consider adding parameter documentation.The new parameters look good for testing purposes, but there are a few minor improvements to make:
- Remove the trailing space after
infer_qvels: False
- Consider adding comments explaining these parameters (similar to other config files)
-infer_qvels: False +infer_qvels: False # Infer qvels from stac output n_frames_per_clip: 1 # Number of frames per clip for processing🧰 Tools
🪛 yamllint (1.35.1)
[error] 8-8: trailing spaces
(trailing-spaces)
configs/stac/stac.yaml (1)
8-9
: Add parameter documentation for consistency.The parameter values look good, but for consistency with other config files (e.g., demo.yaml), consider adding explanatory comments.
-infer_qvels: True +infer_qvels: True # Infer qvels from stac output -n_frames_per_clip: 250 +n_frames_per_clip: 250 # Number of frames per clip for processing🧰 Tools
🪛 yamllint (1.35.1)
[error] 8-8: trailing spaces
(trailing-spaces)
configs/stac/stac_mouse.yaml (1)
8-9
: Add parameter documentation and fix trailing space.The parameters look appropriate for mouse data processing, but maintain documentation consistency with other config files.
-infer_qvels: True +infer_qvels: True # Infer qvels from stac output -n_frames_per_clip: 360 +n_frames_per_clip: 360 # Number of frames per clip for mouse data🧰 Tools
🪛 yamllint (1.35.1)
[error] 8-8: trailing spaces
(trailing-spaces)
configs/stac/demo.yaml (1)
1-1
: Consider creating a schema or template for configuration files.To maintain consistency across configuration files, consider:
- Creating a schema that documents all available parameters and their purposes
- Using a template that includes standardized comments for common parameters
- Adding validation for parameter values (e.g., ensuring n_frames_per_clip > 0)
This would help maintain documentation quality and prevent configuration drift across different use cases.
configs/stac/stac_fly_treadmill.yaml (1)
9-10
: Document the purpose of n_frames_per_clip parameter.This parameter seems to have been moved from model config. Consider adding a comment explaining its purpose and relationship with n_fit_frames.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
⛔ Files ignored due to path filters (1)
videos/direct_render.mp4
is excluded by!**/*.mp4
📒 Files selected for processing (35)
configs/config.yaml
(0 hunks)configs/model/fly_tethered.yaml
(0 hunks)configs/model/fly_treadmill.yaml
(0 hunks)configs/model/rodent.yaml
(0 hunks)configs/model/synth_data.yaml
(0 hunks)configs/stac/demo.yaml
(1 hunks)configs/stac/stac.yaml
(1 hunks)configs/stac/stac_fly_tethered.yaml
(1 hunks)configs/stac/stac_fly_treadmill.yaml
(1 hunks)configs/stac/stac_mouse.yaml
(1 hunks)configs/stac/stac_synth_data.yaml
(1 hunks)demos/api_usage.ipynb
(7 hunks)demos/viz_usage.ipynb
(4 hunks)run_stac.py
(1 hunks)stac_mjx/__init__.py
(1 hunks)stac_mjx/compute_stac.py
(9 hunks)stac_mjx/io.py
(2 hunks)stac_mjx/io_dict_to_hdf5.py
(0 hunks)stac_mjx/main.py
(4 hunks)stac_mjx/op_utils.py
(0 hunks)stac_mjx/stac.py
(14 hunks)stac_mjx/stac_core.py
(5 hunks)stac_mjx/utils.py
(1 hunks)stac_mjx/viz.py
(2 hunks)tests/configs/model/test_mouse.yaml
(0 hunks)tests/configs/model/test_rodent.yaml
(0 hunks)tests/configs/model/test_rodent_label3d.yaml
(0 hunks)tests/configs/model/test_rodent_less_kp_names.yaml
(0 hunks)tests/configs/model/test_rodent_no_kp_names.yaml
(0 hunks)tests/configs/stac/test_stac.yaml
(1 hunks)tests/integration/test_model.py
(1 hunks)tests/test_io.py
(6 hunks)tests/unit/test_controller.py
(1 hunks)tests/unit/test_main.py
(1 hunks)tests/unit/test_utils.py
(6 hunks)
💤 Files with no reviewable changes (12)
- tests/configs/model/test_rodent.yaml
- configs/model/fly_treadmill.yaml
- tests/configs/model/test_rodent_label3d.yaml
- tests/configs/model/test_rodent_no_kp_names.yaml
- tests/configs/model/test_rodent_less_kp_names.yaml
- configs/model/synth_data.yaml
- configs/config.yaml
- configs/model/rodent.yaml
- tests/configs/model/test_mouse.yaml
- configs/model/fly_tethered.yaml
- stac_mjx/io_dict_to_hdf5.py
- stac_mjx/op_utils.py
✅ Files skipped from review due to trivial changes (1)
- tests/integration/test_model.py
🧰 Additional context used
🪛 yamllint (1.35.1)
configs/stac/stac.yaml
[error] 8-8: trailing spaces
(trailing-spaces)
configs/stac/stac_mouse.yaml
[error] 8-8: trailing spaces
(trailing-spaces)
tests/configs/stac/test_stac.yaml
[error] 8-8: trailing spaces
(trailing-spaces)
🪛 Ruff (0.8.2)
stac_mjx/__init__.py
3-3: stac_mjx.utils.enable_xla_flags
imported but unused; consider removing, adding to __all__
, or using a redundant alias
(F401)
4-4: stac_mjx.io.load_mocap
imported but unused; consider removing, adding to __all__
, or using a redundant alias
(F401)
stac_mjx/viz.py
8-8: omegaconf.DictConfig
imported but unused
Remove unused import: omegaconf.DictConfig
(F401)
9-9: typing.Dict
imported but unused
Remove unused import: typing.Dict
(F401)
stac_mjx/main.py
6-6: pickle
imported but unused
Remove unused import: pickle
(F401)
8-8: logging
imported but unused
Remove unused import: logging
(F401)
89-89: Found useless expression. Either assign it to a variable or remove it.
(B018)
stac_mjx/compute_stac.py
5-5: numpy
imported but unused
Remove unused import: numpy
(F401)
stac_mjx/utils.py
10-10: numpy
imported but unused
Remove unused import: numpy
(F401)
11-11: stac_mjx.io
imported but unused
Remove unused import: stac_mjx.io
(F401)
⏰ Context from checks skipped due to timeout of 90000ms (2)
- GitHub Check: Tests (ubuntu-latest, Python 3.11)
- GitHub Check: Lint
🔇 Additional comments (33)
stac_mjx/__init__.py (1)
3-4
: Imports updated correctly to reflect module restructuringThe import statements have been correctly updated to reflect the restructuring of utility functions and data loading methods.
🧰 Tools
🪛 Ruff (0.8.2)
3-3:
stac_mjx.utils.enable_xla_flags
imported but unused; consider removing, adding to__all__
, or using a redundant alias(F401)
4-4:
stac_mjx.io.load_mocap
imported but unused; consider removing, adding to__all__
, or using a redundant alias(F401)
tests/unit/test_main.py (1)
22-23
: Test assertions updated to match new configuration parametersThe test assertions have been appropriately updated to align with the new configuration parameters in the
Config
dataclass, ensuring the test remains valid.run_stac.py (1)
11-11
: LGTM! Function update aligns with codebase changes.The change from
load_data
toload_mocap
is consistent with the broader refactoring across the codebase.tests/unit/test_controller.py (1)
14-14
: LGTM! Test updated to match new API.The test has been correctly updated to use
load_mocap
while maintaining the same test logic and assertions.stac_mjx/viz.py (1)
41-46
: LGTM! Improved data loading implementation.The new implementation using
io.load_stac_data
centralizes data loading logic and provides better organization.tests/unit/test_utils.py (1)
Line range hint
34-94
: LGTM! Comprehensive test coverage maintained.All test cases have been consistently updated to use
load_mocap
while maintaining coverage for:
- Different file types (.nwb, .mat, .h5)
- Error cases (missing/insufficient keypoint names)
- Various configurations
tests/test_io.py (2)
34-34
: LGTM! Function call updated to use new load_mocap method.The change from
load_data
toload_mocap
is consistent with the refactoring across the codebase.
44-44
: Consistent update of load_data to load_mocap across all test functions.All test functions have been updated to use the new
load_mocap
method while maintaining the same test logic and assertions.Also applies to: 56-56, 69-69, 80-80, 94-94
stac_mjx/main.py (3)
Line range hint
18-37
: LGTM! Enhanced configuration loading with validation.The changes improve configuration management by:
- Adding config_name parameter for flexibility
- Adding structured config validation
- Adding user feedback through print statements
72-78
: LGTM! Added vectorized velocity computation.Good use of JAX's
vmap
for efficient batch processing of kinematic data.
86-88
: LGTM! Implemented h5 data saving.The implementation uses a dedicated method
save_data_to_h5
for consistent data storage.Also applies to: 120-120
stac_mjx/stac_core.py (2)
13-13
: LGTM! Consistent transition from op_utils to utils module.All utility function calls have been updated to use the consolidated utils module while maintaining the same functionality.
Also applies to: 54-54, 60-61, 64-64, 104-104, 150-157
199-214
: LGTM! Improved docstring clarity.The docstring for m_opt has been updated with clearer descriptions and better-formatted argument documentation.
stac_mjx/compute_stac.py (2)
70-71
: LGTM! Consistent transition from op_utils to utils module.All utility function calls have been updated to use the consolidated utils module while maintaining the same functionality.
Also applies to: 96-97, 144-144, 165-167, 170-170
201-204
: LGTM! Improved variable naming and added quaternion storage.Changes improve code clarity through:
- More descriptive variable names (qposes, xposes, marker_sites)
- Added storage for quaternion data (xquats)
Also applies to: 260-263, 271-274
stac_mjx/utils.py (5)
14-21
: LGTM! Well-documented GPU optimization.The function enables XLA flags for optimized GPU performance with clear documentation.
23-30
: LGTM! Clear model loading implementation.The function provides a clean interface for loading Mujoco models into MJX with proper documentation.
32-58
: LGTM! Efficient JIT-compiled kinematics functions.The functions are well-documented and use JAX's JIT compilation for improved performance.
60-99
: LGTM! Comprehensive site position handling.The functions provide a complete set of operations for getting and setting site positions with proper documentation.
272-313
: LGTM! Well-structured velocity computation.The function handles both free and constrained joints with proper velocity limits and clear documentation.
demos/api_usage.ipynb (3)
54-54
: LGTM! Updated function call.The function call has been updated from
load_data
toload_mocap
to reflect the API changes.
269-270
: LGTM! Clear output messages.The output messages clearly indicate the data saving path and skipped operations.
293-293
: LGTM! Updated kernel display name.The kernel display name has been updated to "stac-mjx-env" for consistency.
stac_mjx/stac.py (3)
15-15
: LGTM! Consolidated imports.The imports have been reorganized to use
utils
instead ofop_utils
.
111-111
: LGTM! Added freejoint flag.The
_freejoint
flag is correctly initialized based on the model's joint type.
433-442
: LGTM! Improved data packaging.The data packaging now uses descriptive variable names and returns a structured
StacData
object.demos/viz_usage.ipynb (3)
78-78
: LGTM! Updated data format.The data path has been updated to use the .h5 format instead of .p.
83-89
: LGTM! Improved function call formatting.The
viz_stac
function call has been formatted for better readability with one parameter per line.
105-105
: LGTM! Updated kernel display name.The kernel display name has been updated to "stac-mjx-env" for consistency.
configs/stac/stac_synth_data.yaml (1)
1-2
: Changes look good and well documented!The updates align well with the PR objective of switching to h5 format, and the parameters are appropriately documented.
Also applies to: 7-9
configs/stac/demo.yaml (1)
1-2
: Well documented changes with clear explanations!The changes look good and include helpful comments explaining the parameters. Consider reusing these well-written comments across other config files for consistency.
Also applies to: 5-5, 8-9
configs/stac/stac_fly_treadmill.yaml (1)
Line range hint
1-2
: Consider updating output paths to use .h5 extension.The PR objective is to save output data as h5, but the output paths still use pickle format (.p extension). Consider updating these paths to use .h5 extension for consistency.
configs/stac/stac_fly_tethered.yaml (1)
12-13
: Document n_frames_per_clip and verify value consistency.
- Consider adding a comment explaining the purpose of this parameter and its relationship with n_fit_frames.
- The value (300) differs from the treadmill config (581). Please verify if this difference is intentional.
✅ Verification successful
Values are intentionally different, but documentation needed
The different
n_frames_per_clip
values across configs are intentional, as this parameter controls batch processing size for parallel computation. Suggested documentation forconfigs/stac/stac_fly_tethered.yaml
:n_frames_per_clip: 300 # Number of frames to process in parallel. Must divide evenly into total mocap frames.Note: This parameter is distinct from
n_fit_frames
(601), which is used only for the initial offset fitting phase.🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash # Check n_frames_per_clip values across configs echo "Checking n_frames_per_clip values across configs:" rg -g '*.yaml' 'n_frames_per_clip:'Length of output: 545
Script:
#!/bin/bash # Search for documentation and usage of n_frames_per_clip echo "Searching for documentation and usage of n_frames_per_clip:" rg -g '!*.yaml' 'n_frames_per_clip' -B 2 -A 2 echo -e "\nSearching for n_fit_frames:" rg 'n_fit_frames' -B 2 -A 2Length of output: 4103
|
||
def viz_stac( | ||
data_path: Union[Path, str], | ||
cfg: DictConfig, | ||
# cfg: DictConfig, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Remove commented parameter instead of commenting it out.
Clean up the function signature by removing the commented-out parameter entirely, as it's no longer used.
- # cfg: DictConfig,
import numpy as np | ||
from stac_mjx import io |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Remove unused imports.
The following imports are not used in this file:
numpy
stac_mjx.io
Apply this diff to remove the unused imports:
-import numpy as np
-from stac_mjx import io
🧰 Tools
🪛 Ruff (0.8.2)
10-10: numpy
imported but unused
Remove unused import: numpy
(F401)
11-11: stac_mjx.io
imported but unused
Remove unused import: stac_mjx.io
(F401)
def make_qs(q0, qs_to_opt, q): | ||
"""Create new set of qs combining initial and new qs for part optimization based on qs_to_opt. | ||
|
||
Args: | ||
q0 (jax.Array): initial joint angles | ||
qs_to_opt (jax.Array): joint angles that were optimized | ||
q (jax.Array): new joint angles | ||
|
||
Returns: | ||
jp.Array: resulting set of joint angles | ||
""" | ||
return jp.copy((1 - qs_to_opt) * q0 + qs_to_opt * jp.copy(q)) | ||
|
||
|
||
def replace_qs(mjx_model: mjx.Model, mjx_data: mjx.Data, q): | ||
"""Replace joint angles in mjx.Data with new ones and performs forward kinematics. | ||
|
||
Args: | ||
mjx_model (mjx.Model): | ||
mjx_data (mjx.Data): | ||
q (jax.Array): new joint angles | ||
|
||
Returns: | ||
mjx.Data: resulting mjx Data | ||
""" | ||
if q is None: | ||
print("optimization failed, continuing") | ||
|
||
else: | ||
mjx_data = mjx_data.replace(qpos=q) | ||
mjx_data = kinematics(mjx_model, mjx_data) | ||
|
||
return mjx_data | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Add error handling for None input in replace_qs
.
The function handles None
input by printing a message and continuing, which could mask errors. Consider raising an exception or returning early with a meaningful error message.
Apply this diff to improve error handling:
def replace_qs(mjx_model: mjx.Model, mjx_data: mjx.Data, q):
"""Replace joint angles in mjx.Data with new ones and performs forward kinematics.
Args:
mjx_model (mjx.Model):
mjx_data (mjx.Data):
q (jax.Array): new joint angles
Returns:
mjx.Data: resulting mjx Data
+
+ Raises:
+ ValueError: If q is None, indicating optimization failure
"""
if q is None:
- print("optimization failed, continuing")
+ raise ValueError("Optimization failed: received None for joint angles")
else:
mjx_data = mjx_data.replace(qpos=q)
mjx_data = kinematics(mjx_model, mjx_data)
return mjx_data
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
def make_qs(q0, qs_to_opt, q): | |
"""Create new set of qs combining initial and new qs for part optimization based on qs_to_opt. | |
Args: | |
q0 (jax.Array): initial joint angles | |
qs_to_opt (jax.Array): joint angles that were optimized | |
q (jax.Array): new joint angles | |
Returns: | |
jp.Array: resulting set of joint angles | |
""" | |
return jp.copy((1 - qs_to_opt) * q0 + qs_to_opt * jp.copy(q)) | |
def replace_qs(mjx_model: mjx.Model, mjx_data: mjx.Data, q): | |
"""Replace joint angles in mjx.Data with new ones and performs forward kinematics. | |
Args: | |
mjx_model (mjx.Model): | |
mjx_data (mjx.Data): | |
q (jax.Array): new joint angles | |
Returns: | |
mjx.Data: resulting mjx Data | |
""" | |
if q is None: | |
print("optimization failed, continuing") | |
else: | |
mjx_data = mjx_data.replace(qpos=q) | |
mjx_data = kinematics(mjx_model, mjx_data) | |
return mjx_data | |
def make_qs(q0, qs_to_opt, q): | |
"""Create new set of qs combining initial and new qs for part optimization based on qs_to_opt. | |
Args: | |
q0 (jax.Array): initial joint angles | |
qs_to_opt (jax.Array): joint angles that were optimized | |
q (jax.Array): new joint angles | |
Returns: | |
jp.Array: resulting set of joint angles | |
""" | |
return jp.copy((1 - qs_to_opt) * q0 + qs_to_opt * jp.copy(q)) | |
def replace_qs(mjx_model: mjx.Model, mjx_data: mjx.Data, q): | |
"""Replace joint angles in mjx.Data with new ones and performs forward kinematics. | |
Args: | |
mjx_model (mjx.Model): | |
mjx_data (mjx.Data): | |
q (jax.Array): new joint angles | |
Returns: | |
mjx.Data: resulting mjx Data | |
Raises: | |
ValueError: If q is None, indicating optimization failure | |
""" | |
if q is None: | |
raise ValueError("Optimization failed: received None for joint angles") | |
else: | |
mjx_data = mjx_data.replace(qpos=q) | |
mjx_data = kinematics(mjx_model, mjx_data) | |
return mjx_data |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Nitpick comments (3)
stac_mjx/utils.py (1)
10-11
: Remove unused imports.The following imports are not used in this file:
numpy
stac_mjx.io
Apply this diff to remove the unused imports:
-import numpy as np -from stac_mjx import io🧰 Tools
🪛 Ruff (0.8.2)
10-10:
numpy
imported but unusedRemove unused import:
numpy
(F401)
11-11:
stac_mjx.io
imported but unusedRemove unused import:
stac_mjx.io
(F401)
stac_mjx/io.py (2)
265-326
: Add error handling for HDF5 operations.While the HDF5 saving functions are well-structured, they could benefit from error handling for common issues like:
- File permission errors
- Disk space issues
- Invalid data types
Would you like me to provide an example implementation with proper error handling?
328-368
: Add input validation for HDF5 loading.Consider adding validation for:
- File existence and format
- Required dataset presence
- Data shape consistency
Would you like me to provide an example implementation with proper validation?
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
stac_mjx/io.py
(2 hunks)stac_mjx/stac.py
(14 hunks)stac_mjx/utils.py
(1 hunks)stac_mjx/viz.py
(2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- stac_mjx/stac.py
🧰 Additional context used
🪛 Ruff (0.8.2)
stac_mjx/viz.py
8-8: omegaconf.DictConfig
imported but unused
Remove unused import: omegaconf.DictConfig
(F401)
9-9: typing.Dict
imported but unused
Remove unused import: typing.Dict
(F401)
stac_mjx/utils.py
10-10: numpy
imported but unused
Remove unused import: numpy
(F401)
11-11: stac_mjx.io
imported but unused
Remove unused import: stac_mjx.io
(F401)
⏰ Context from checks skipped due to timeout of 90000ms (2)
- GitHub Check: Tests (ubuntu-latest, Python 3.11)
- GitHub Check: Lint
🔇 Additional comments (4)
stac_mjx/viz.py (2)
14-14
: Remove commented parameter instead of commenting it out.Clean up the function signature by removing the commented-out parameter entirely, as it's no longer used.
40-44
: LGTM! Good improvements to data handling.The changes improve code organization by:
- Centralizing data loading in
io.load_stac_data
.- Using structured data class for better data organization.
- Providing configuration back to caller for potential use.
Also applies to: 54-63
stac_mjx/utils.py (1)
126-128
: Add error handling for None input inreplace_qs
.The function handles
None
input by printing a message and continuing, which could mask errors. Consider raising an exception or returning early with a meaningful error message.stac_mjx/io.py (1)
23-96
: Well-structured configuration and data management using dataclasses!The dataclass hierarchy provides:
- Clear organization of configuration parameters
- Type safety through annotations
- Easy serialization/deserialization
|
||
@dataclass | ||
class ModelConfig: | ||
"""Configuration for body model.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Any chance we could document what the attributes mean in this and subsequent dataclass docstrings?
stac_mjx/io.py
Outdated
kp_data = np.array(f["kp_data"]) | ||
marker_sites = np.array(f["marker_sites"]) | ||
offsets = np.array(f["offsets"]) | ||
qpos = np.array(f["qpos"]) | ||
qvel = np.array(f["qvel"]) | ||
xpos = np.array(f["xpos"]) | ||
xquat = np.array(f["xquat"]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Canonical way to load the datasets into memory is like:
with h5py.File(file_path, "r") as f:
kp_data = f["kp_data"][()]
Ref: https://docs.h5py.org/en/stable/high/dataset.html#reading-writing-data
Numpy array casting has the same effect by forcing the read op, but doing it the way above maintains type and works with scalars and other primitives without coercing them into arrays.
stac_mjx/io_dict_to_hdf5.py
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍
stac_mjx/main.py
Outdated
|
||
# Stop here if not doing ik only phase | ||
if cfg.stac.skip_ik_only == 1: | ||
logging.info("skipping ik_only()") | ||
print("skipping ik_only()") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's give the users a bit more of an informative error message? Maybe:
print("skipping ik_only()") | |
print("Skipping IK-only phase. Only offsets will be returned. To change this behavior, set cfg.stac.skip_ik_only to 0.") |
|
||
Args: | ||
qpos_trajectory (jp.ndarray): trajectory of qpos values T x ? | ||
Note assumes has freejoint as the first 7 dimensions |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we just have this as a kwarg with 7 as default? Will this change across body models?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The first 7 dimensions is always the freejoint if it exists; the freejoint
arg accounts for this so I'll use that when calling this function
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🧹 Nitpick comments (7)
stac_mjx/main.py (1)
6-6
: Remove unused imports.The following imports are not used in the code:
pickle
logging
-import pickle import time -import logging from omegaconf import DictConfig, OmegaConfAlso applies to: 8-8
🧰 Tools
🪛 Ruff (0.8.2)
6-6:
pickle
imported but unusedRemove unused import:
pickle
(F401)
stac_mjx/viz.py (1)
8-9
: Remove unused imports.The following imports are not used in the code:
DictConfig
fromomegaconf
Dict
fromtyping
-from omegaconf import DictConfig -from typing import Union, Dict +from typing import Union🧰 Tools
🪛 Ruff (0.8.2)
8-8:
omegaconf.DictConfig
imported but unusedRemove unused import:
omegaconf.DictConfig
(F401)
9-9:
typing.Dict
imported but unusedRemove unused import:
typing.Dict
(F401)
stac_mjx/io.py (3)
269-282
: Consider error handling for HDF5 operations.The
save_dict_to_hdf5
function should handle potential HDF5 errors that may occur during group creation or attribute setting.def save_dict_to_hdf5(group, dictionary): """Save a dictionary to an HDF5 group. Args: group (h5py.Group): HDF5 group to save the dictionary to. dictionary (dict): Dictionary to save. + + Raises: + h5py.H5Error: If there's an error during HDF5 operations. """ - for key, value in dictionary.items(): - if isinstance(value, dict): - subgroup = group.create_group(key) - save_dict_to_hdf5(subgroup, value) - else: - group.attrs[key] = value + try: + for key, value in dictionary.items(): + if isinstance(value, dict): + subgroup = group.create_group(key) + save_dict_to_hdf5(subgroup, value) + else: + group.attrs[key] = value + except (TypeError, ValueError) as e: + raise ValueError(f"Failed to save dictionary to HDF5: {e}")
284-330
: Consider adding data validation.The
save_data_to_h5
function should validate input arrays for shape consistency and non-null values before saving.def save_data_to_h5(...): """Save configuration and STAC data to an HDF5 file.""" + # Validate input arrays + if not all(isinstance(arr, np.ndarray) for arr in [kp_data, marker_sites, offsets, qpos, xpos, xquat, qvel]): + raise ValueError("All data arrays must be numpy arrays") + + # Validate array shapes + if len(kp_names) != kp_data.shape[1]: + raise ValueError(f"Mismatch between kp_names length ({len(kp_names)}) and kp_data shape ({kp_data.shape})") + with h5py.File(file_path, "w") as f: # Save config as a YAML string config_yaml = OmegaConf.to_yaml(OmegaConf.structured(config))
332-372
: Consider adding file existence check.The
load_stac_data
function should verify file existence before attempting to load it.def load_stac_data(file_path) -> tuple[Config, StacData]: """Load configuration and STAC data from an HDF5 file.""" + if not os.path.exists(file_path): + raise FileNotFoundError(f"HDF5 file not found: {file_path}") + with h5py.File(file_path, "r") as f:stac_mjx/utils.py (2)
281-281
: Improve documentation clarity for freejoint dimensions.Consider adding a constant like
FREEJOINT_DIMS = 7
or enhancing the comment to explain why the first 7 dimensions correspond to the freejoint (3 for translation + 4 for quaternion rotation).
298-305
: Consider vectorizing the quaternion operations.The loop for computing
qvel_gyro
could be vectorized usingvmap
for better performance:- qvel_gyro = [] - for t in range(qpos_trajectory.shape[0] - 1): - normed_diff = quat_diff( - qpos_trajectory[t, 3:7], qpos_trajectory[t + 1, 3:7] - ) - normed_diff /= jp.linalg.norm(normed_diff) - angle = quat_to_axisangle(normed_diff) - qvel_gyro.append(angle / dt) - qvel_gyro = jp.stack(qvel_gyro) + def compute_gyro(t): + normed_diff = quat_diff( + qpos_trajectory[t, 3:7], qpos_trajectory[t + 1, 3:7] + ) + normed_diff /= jp.linalg.norm(normed_diff) + return quat_to_axisangle(normed_diff) / dt + qvel_gyro = jax.vmap(compute_gyro)(jp.arange(qpos_trajectory.shape[0] - 1))
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
stac_mjx/io.py
(2 hunks)stac_mjx/main.py
(4 hunks)stac_mjx/utils.py
(1 hunks)stac_mjx/viz.py
(2 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
stac_mjx/main.py
6-6: pickle
imported but unused
Remove unused import: pickle
(F401)
8-8: logging
imported but unused
Remove unused import: logging
(F401)
89-89: Found useless expression. Either assign it to a variable or remove it.
(B018)
stac_mjx/utils.py
10-10: numpy
imported but unused
Remove unused import: numpy
(F401)
11-11: stac_mjx.io
imported but unused
Remove unused import: stac_mjx.io
(F401)
stac_mjx/viz.py
8-8: omegaconf.DictConfig
imported but unused
Remove unused import: omegaconf.DictConfig
(F401)
9-9: typing.Dict
imported but unused
Remove unused import: typing.Dict
(F401)
⏰ Context from checks skipped due to timeout of 90000ms (2)
- GitHub Check: Tests (ubuntu-latest, Python 3.11)
- GitHub Check: Lint
🔇 Additional comments (14)
stac_mjx/main.py (5)
18-20
: LGTM! Well-documented function signature update.The addition of the
config_name
parameter with a default value maintains backward compatibility while adding flexibility.
32-36
: LGTM! Enhanced configuration validation.Good improvements:
- Using structured configuration validation
- Adding user feedback about config loading
72-78
: LGTM! Efficient velocity computation using JAX's vmap.Good optimization using vectorized operations for computing velocities.
97-99
: Enhance error message with more information.Consider providing more informative error messages to guide users.
- "Skipping IK-only phase. To change this behavior, set cfg.stac.skip_ik_only to 0." + "Skipping IK-only phase. Only offsets will be returned. To change this behavior, set cfg.stac.skip_ik_only to 0."
86-88
: LGTM! Consistent use of H5 format for data saving.Good use of the new
save_data_to_h5
function for consistent data storage.Also applies to: 128-128
stac_mjx/viz.py (2)
40-44
: LGTM! Clean data loading implementation.The code now uses the centralized
load_stac_data
function to load both configuration and data, which is a cleaner approach than the previous implementation.
Line range hint
54-64
: LGTM! Return statement updated to include configuration.The return statement now includes the configuration object, which provides additional context alongside the rendering result.
stac_mjx/io.py (1)
23-104
: LGTM! Well-documented dataclasses.The dataclasses are well-structured and each attribute is properly documented with clear descriptions.
stac_mjx/utils.py (6)
10-11
: Remove unused imports.The imports of
numpy
andstac_mjx.io
are not used in this file.🧰 Tools
🪛 Ruff (0.8.2)
10-10:
numpy
imported but unusedRemove unused import:
numpy
(F401)
11-11:
stac_mjx.io
imported but unusedRemove unused import:
stac_mjx.io
(F401)
14-21
: LGTM! Good GPU optimization.The function correctly checks for GPU backend before enabling XLA flags for optimized performance.
115-134
: Improve error handling inreplace_qs
.The function should raise an exception for
None
input rather than printing a message and continuing.
23-58
: LGTM! Well-structured kinematics functions.The model loading and kinematics functions are well-implemented with proper JIT compilation for performance.
60-99
: LGTM! Clear site position manipulation.The site position manipulation functions are well-documented and correctly implemented.
136-269
: LGTM! Robust quaternion operations.The quaternion operations are thoroughly implemented with proper error handling and precision checks.
io.save_data_to_h5( | ||
config=cfg, file_path=fit_offsets_path, **fit_offsets_data.as_dict() | ||
) | ||
(fit_offsets_data, fit_offsets_path) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove unused tuple expression.
This line creates a tuple that is neither assigned to a variable nor returned.
- (fit_offsets_data, fit_offsets_path)
🧰 Tools
🪛 Ruff (0.8.2)
89-89: Found useless expression. Either assign it to a variable or remove it.
(B018)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Nitpick comments (3)
stac_mjx/io.py (3)
24-47
: Enhance dataclass attribute documentation.While the dataclasses are well-structured, the attribute documentation could be more descriptive to improve maintainability. Consider adding more detailed descriptions for each field, including:
- Valid value ranges or constraints
- Units of measurement where applicable
- Dependencies between fields
- Examples of valid values
Example for
ModelConfig
:@dataclass class ModelConfig: """Configuration for body model. Attributes: MJCF_PATH: Path to the MuJoCo XML model file (e.g., "models/mouse.xml") FTOL: Tolerance threshold for optimization convergence (e.g., 1e-6) ROOT_FTOL: Tolerance threshold for root optimization (e.g., 1e-6) LIMB_FTOL: Tolerance threshold for limb optimization (e.g., 1e-6) N_ITERS: Number of iterations for STAC algorithm (e.g., 100) ... """Also applies to: 49-56, 58-72, 74-80, 82-104
341-358
: Use canonical HDF5 data loading pattern.Consider using the canonical way to load data from HDF5 files for better type preservation and handling of scalars:
- kp_data = f["kp_data"][()] - marker_sites = f["marker_sites"][()] - offsets = f["offsets"][()] - qpos = f["qpos"][()] - qvel = f["qvel"][()] - xpos = f["xpos"][()] - xquat = f["xquat"][()] + with h5py.File(file_path, "r") as f: + kp_data = f["kp_data"][()] + marker_sites = f["marker_sites"][()] + offsets = f["offsets"][()] + qpos = f["qpos"][()] + qvel = f["qvel"][()] + xpos = f["xpos"][()] + xquat = f["xquat"][()]Reference: https://docs.h5py.org/en/stable/high/dataset.html#reading-writing-data
323-329
: Consider optimizing HDF5 compression settings.The current implementation uses gzip compression without specifying the compression level. Consider tuning the compression settings based on your data characteristics:
- f.create_dataset("kp_data", data=kp_data, compression="gzip") - f.create_dataset("marker_sites", data=marker_sites, compression="gzip") - f.create_dataset("offsets", data=offsets, compression="gzip") - f.create_dataset("qpos", data=qpos, compression="gzip") - f.create_dataset("qvel", data=qvel, compression="gzip") - f.create_dataset("xpos", data=xpos, compression="gzip") - f.create_dataset("xquat", data=xquat, compression="gzip") + f.create_dataset("kp_data", data=kp_data, compression="gzip", compression_opts=4) + f.create_dataset("marker_sites", data=marker_sites, compression="gzip", compression_opts=4) + f.create_dataset("offsets", data=offsets, compression="gzip", compression_opts=4) + f.create_dataset("qpos", data=qpos, compression="gzip", compression_opts=4) + f.create_dataset("qvel", data=qvel, compression="gzip", compression_opts=4) + f.create_dataset("xpos", data=xpos, compression="gzip", compression_opts=4) + f.create_dataset("xquat", data=xquat, compression="gzip", compression_opts=4)The compression level (1-9) trades off compression ratio vs. speed:
- Lower values (1-4): Faster compression, larger file size
- Higher values (5-9): Better compression, slower performance
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
stac_mjx/io.py
(2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (2)
- GitHub Check: Tests (ubuntu-latest, Python 3.11)
- GitHub Check: Lint
🔇 Additional comments (1)
stac_mjx/io.py (1)
Line range hint
106-268
: LGTM! Well-structured data loading function.The renamed
load_mocap
function is well-documented, has proper error handling, and follows good practices. The docstring clearly describes the function's behavior, arguments, return values, and potential errors.
Summary by CodeRabbit
Release Notes
Configuration Changes
n_frames_per_clip
across multiple configuration files.infer_qvels
configuration option to control joint velocity inference.N_FRAMES_PER_CLIP
from various model configurations.Data Handling
load_data
toload_mocap
..p
) to HDF5 (.h5
) file formats for data storage.Performance Improvements
utils
module.Visualization
This release focuses on improving data handling, configuration management, and overall code organization.