Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Save output data as h5 #84

Merged
merged 14 commits into from
Feb 6, 2025
Merged

Save output data as h5 #84

merged 14 commits into from
Feb 6, 2025

Conversation

charles-zhng
Copy link
Collaborator

@charles-zhng charles-zhng commented Jan 21, 2025

Summary by CodeRabbit

Release Notes

  • Configuration Changes

    • Introduced new configuration parameter n_frames_per_clip across multiple configuration files.
    • Added infer_qvels configuration option to control joint velocity inference.
    • Removed N_FRAMES_PER_CLIP from various model configurations.
  • Data Handling

    • Updated data loading method from load_data to load_mocap.
    • Transitioned from pickle (.p) to HDF5 (.h5) file formats for data storage.
    • Enhanced data serialization and configuration management.
  • Performance Improvements

    • Consolidated utility functions into a single utils module.
    • Improved velocity computation for kinematic data.
    • Streamlined configuration validation and loading processes.
  • Visualization

    • Updated visualization function to use new data loading mechanism.
    • Simplified rendering and data retrieval processes.

This release focuses on improving data handling, configuration management, and overall code organization.

Copy link
Contributor

coderabbitai bot commented Jan 21, 2025

Walkthrough

This pull request introduces significant changes to the Stac-MJX library, focusing on data handling, configuration management, and utility functions. The modifications span across multiple files, with key updates including the removal of N_FRAMES_PER_CLIP from various model configurations, the introduction of new configuration parameters such as n_frames_per_clip and infer_qvels in STAC configurations, and the replacement of the load_data function with load_mocap. Additionally, there is a transition from op_utils to utils for various utility functions, enhancing overall code organization and functionality.

Changes

File/Directory Change Summary
configs/config.yaml Removed commented-out defaults for FLY_MODEL.
configs/model/fly_tethered.yaml Updated MJCF_PATH and removed N_FRAMES_PER_CLIP.
configs/model/fly_treadmill.yaml Removed N_FRAMES_PER_CLIP.
configs/model/rodent.yaml Removed N_FRAMES_PER_CLIP.
configs/model/synth_data.yaml Removed N_FRAMES_PER_CLIP.
configs/stac/demo.yaml Updated file paths, increased n_fit_frames, added infer_qvels and n_frames_per_clip.
configs/stac/stac.yaml Added infer_qvels and n_frames_per_clip, removed commented-out FLY_MODEL settings.
configs/stac/stac_fly_tethered.yaml Added n_frames_per_clip.
configs/stac/stac_fly_treadmill.yaml Added n_frames_per_clip.
configs/stac/stac_mouse.yaml Added infer_qvels and n_frames_per_clip.
configs/stac/stac_synth_data.yaml Updated file paths, added infer_qvels and n_frames_per_clip.
demos/api_usage.ipynb Updated imports and function calls, removed autoreload commands.
demos/viz_usage.ipynb Updated data paths and function calls, removed autoreload commands.
run_stac.py Changed data loading from load_data to load_mocap.
stac_mjx/__init__.py Updated import paths for enable_xla_flags and load_data.
stac_mjx/compute_stac.py Updated import paths and variable names, removed comments.
stac_mjx/io.py Introduced new dataclasses, replaced load_data with load_mocap, added new saving/loading functions.
stac_mjx/io_dict_to_hdf5.py Removed file with utility functions for HDF5 handling.
stac_mjx/main.py Updated load_configs and run_stac functions for improved configurability.
stac_mjx/op_utils.py Removed file with utility functions for Mujoco.
stac_mjx/stac.py Updated initialization and method signatures, improved variable naming.
stac_mjx/stac_core.py Updated import paths and function calls to use utils.
stac_mjx/utils.py Added new utility functions for Mujoco operations.
stac_mjx/viz.py Simplified data loading process using io.load_stac_data.
tests/configs/model/test_mouse.yaml Removed N_FRAMES_PER_CLIP.
tests/configs/model/test_rodent.yaml Removed N_FRAMES_PER_CLIP.
tests/configs/model/test_rodent_label3d.yaml Removed N_FRAMES_PER_CLIP.
tests/configs/model/test_rodent_less_kp_names.yaml Removed N_FRAMES_PER_CLIP.
tests/configs/model/test_rodent_no_kp_names.yaml Removed N_FRAMES_PER_CLIP.
tests/configs/stac/test_stac.yaml Added infer_qvels and n_frames_per_clip.
tests/integration/test_model.py Updated import path for mjx_load.
tests/test_io.py Replaced load_data with load_mocap in tests.
tests/unit/test_controller.py Updated function call from load_data to load_mocap.
tests/unit/test_main.py Updated assertions in test_load_configs.
tests/unit/test_utils.py Replaced load_data with load_mocap in tests.

Possibly related PRs

  • add demo_viz.p #45: The changes in demos/viz_usage.ipynb involve updates to data paths and video playback settings, which may indirectly relate to the configuration changes in configs/config.yaml regarding how data is loaded and processed, particularly if the visualizations rely on the configurations defined in the YAML files.

Suggested reviewers

  • jf514

Poem

🐰 Hop, hop, through configs anew,
Data flows like a digital brew
Utilities dance, utils take flight
Mocap's journey, now sleek and bright
CodeRabbit's magic makes science delight! 🔬

✨ Finishing Touches
  • 📝 Generate Docstrings (Beta)

Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media?

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Generate unit testing code for this file.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai generate unit testing code for this file.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and generate unit testing code.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai generate docstrings to generate docstrings for this PR. (Beta)
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

Copy link

codecov bot commented Jan 21, 2025

Codecov Report

Attention: Patch coverage is 42.75362% with 158 lines in your changes missing coverage. Please review.

Project coverage is 44.26%. Comparing base (f3980e4) to head (52e3207).
Report is 1 commits behind head on main.

Files with missing lines Patch % Lines
stac_mjx/utils.py 44.04% 47 Missing ⚠️
stac_mjx/io.py 63.54% 35 Missing ⚠️
stac_mjx/main.py 29.41% 24 Missing ⚠️
stac_mjx/stac.py 15.38% 22 Missing ⚠️
stac_mjx/compute_stac.py 11.76% 15 Missing ⚠️
stac_mjx/stac_core.py 10.00% 9 Missing ⚠️
stac_mjx/viz.py 14.28% 6 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main      #84      +/-   ##
==========================================
+ Coverage   39.86%   44.26%   +4.40%     
==========================================
  Files          10        9       -1     
  Lines         582      689     +107     
==========================================
+ Hits          232      305      +73     
- Misses        350      384      +34     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🔭 Outside diff range comments (3)
stac_mjx/io.py (1)

Line range hint 113-115: Update error message to include '.h5' as a supported file extension

The error message currently indicates that only '.nwb' or '.mat' files are supported, but the code also supports '.h5' files. Please update the error message to reflect this.

Apply this diff to fix the error message:

             raise ValueError(
-                "Unsupported file extension. Please provide a .nwb or .mat file."
+                "Unsupported file extension. Please provide a .nwb, .h5, or .mat file."
             )
stac_mjx/viz.py (1)

Line range hint 55-65: Breaking change: Function return signature has changed.

The function now returns a tuple (cfg, frames) instead of just frames. This change could break existing callers of the function.

Consider one of these approaches:

  1. Document the breaking change and update all callers
  2. Keep the original return signature and only return the frames
configs/stac/stac_fly_tethered.yaml (1)

Line range hint 1-2: Standardize output file formats to use .h5.

While ik_only_path uses .h5, fit_offsets_path still uses .p (pickle) format. Consider standardizing all outputs to use h5 format as per PR objective.

🧹 Nitpick comments (11)
stac_mjx/viz.py (1)

8-9: Clean up unused imports.

The following imports are not used in the code:

  • DictConfig from omegaconf
  • Dict from typing
-from omegaconf import DictConfig
-from typing import Union, Dict
+from typing import Union
🧰 Tools
🪛 Ruff (0.8.2)

8-8: omegaconf.DictConfig imported but unused

Remove unused import: omegaconf.DictConfig

(F401)


9-9: typing.Dict imported but unused

Remove unused import: typing.Dict

(F401)

stac_mjx/main.py (2)

6-8: Remove unused imports.

The following imports are not used in the code:

  • pickle
  • logging
-import pickle
-import logging
🧰 Tools
🪛 Ruff (0.8.2)

6-6: pickle imported but unused

Remove unused import: pickle

(F401)


8-8: logging imported but unused

Remove unused import: logging

(F401)


89-89: Remove useless expression.

The tuple expression on line 89 is not being used.

-        (fit_offsets_data, fit_offsets_path)
🧰 Tools
🪛 Ruff (0.8.2)

89-89: Found useless expression. Either assign it to a variable or remove it.

(B018)

stac_mjx/compute_stac.py (1)

5-5: Remove unused import.

The numpy import is not used in the code.

-import numpy as np
🧰 Tools
🪛 Ruff (0.8.2)

5-5: numpy imported but unused

Remove unused import: numpy

(F401)

stac_mjx/utils.py (1)

199-209: Uncomment and fix the error handling in _clip_within_precision.

The error handling is commented out due to JIT compilation issues. Consider using jax.debug.print for debugging or implement a different error handling strategy.

Apply this diff to improve error handling:

-    # This is raising an error when jitted
-    # def _raise_if_not_in_precision():
-    #     if (number < low - precision).any() or (number > high + precision).any():
-    #         raise ValueError(
-    #             "Input {:.12f} not inside range [{:.12f}, {:.12f}] with precision {}".format(
-    #                 number, low, high, precision
-    #             )
-    #         )
-
-    # jax.debug.callback(_raise_if_not_in_precision)
+    # Use jax.debug.print for debugging in JIT-compiled functions
+    def _check_precision():
+        is_below = (number < low - precision).any()
+        is_above = (number > high + precision).any()
+        if is_below or is_above:
+            jax.debug.print(
+                "Warning: Input {x} not inside range [{l}, {h}] with precision {p}",
+                x=number,
+                l=low,
+                h=high,
+                p=precision
+            )
+    jax.lax.cond(
+        (number < low - precision).any() | (number > high + precision).any(),
+        lambda _: _check_precision(),
+        lambda _: None,
+        operand=None
+    )
stac_mjx/stac.py (1)

200-207: Improve error stats type handling.

The function now accepts a list but uses numpy array operations. Consider using JAX's numpy (jp) for consistency with the rest of the codebase.

Apply this diff to use JAX's numpy:

-        flattened_errors = np.array(errors).reshape(-1)
-        mean = np.mean(flattened_errors)
-        std = np.std(flattened_errors)
+        flattened_errors = jp.array(errors).reshape(-1)
+        mean = jp.mean(flattened_errors)
+        std = jp.std(flattened_errors)
tests/configs/stac/test_stac.yaml (1)

8-9: Fix trailing space and consider adding parameter documentation.

The new parameters look good for testing purposes, but there are a few minor improvements to make:

  1. Remove the trailing space after infer_qvels: False
  2. Consider adding comments explaining these parameters (similar to other config files)
-infer_qvels: False 
+infer_qvels: False  # Infer qvels from stac output
 n_frames_per_clip: 1  # Number of frames per clip for processing
🧰 Tools
🪛 yamllint (1.35.1)

[error] 8-8: trailing spaces

(trailing-spaces)

configs/stac/stac.yaml (1)

8-9: Add parameter documentation for consistency.

The parameter values look good, but for consistency with other config files (e.g., demo.yaml), consider adding explanatory comments.

-infer_qvels: True 
+infer_qvels: True  # Infer qvels from stac output
-n_frames_per_clip: 250
+n_frames_per_clip: 250  # Number of frames per clip for processing
🧰 Tools
🪛 yamllint (1.35.1)

[error] 8-8: trailing spaces

(trailing-spaces)

configs/stac/stac_mouse.yaml (1)

8-9: Add parameter documentation and fix trailing space.

The parameters look appropriate for mouse data processing, but maintain documentation consistency with other config files.

-infer_qvels: True 
+infer_qvels: True  # Infer qvels from stac output
-n_frames_per_clip: 360
+n_frames_per_clip: 360  # Number of frames per clip for mouse data
🧰 Tools
🪛 yamllint (1.35.1)

[error] 8-8: trailing spaces

(trailing-spaces)

configs/stac/demo.yaml (1)

1-1: Consider creating a schema or template for configuration files.

To maintain consistency across configuration files, consider:

  1. Creating a schema that documents all available parameters and their purposes
  2. Using a template that includes standardized comments for common parameters
  3. Adding validation for parameter values (e.g., ensuring n_frames_per_clip > 0)

This would help maintain documentation quality and prevent configuration drift across different use cases.

configs/stac/stac_fly_treadmill.yaml (1)

9-10: Document the purpose of n_frames_per_clip parameter.

This parameter seems to have been moved from model config. Consider adding a comment explaining its purpose and relationship with n_fit_frames.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f3980e4 and 2075ed4.

⛔ Files ignored due to path filters (1)
  • videos/direct_render.mp4 is excluded by !**/*.mp4
📒 Files selected for processing (35)
  • configs/config.yaml (0 hunks)
  • configs/model/fly_tethered.yaml (0 hunks)
  • configs/model/fly_treadmill.yaml (0 hunks)
  • configs/model/rodent.yaml (0 hunks)
  • configs/model/synth_data.yaml (0 hunks)
  • configs/stac/demo.yaml (1 hunks)
  • configs/stac/stac.yaml (1 hunks)
  • configs/stac/stac_fly_tethered.yaml (1 hunks)
  • configs/stac/stac_fly_treadmill.yaml (1 hunks)
  • configs/stac/stac_mouse.yaml (1 hunks)
  • configs/stac/stac_synth_data.yaml (1 hunks)
  • demos/api_usage.ipynb (7 hunks)
  • demos/viz_usage.ipynb (4 hunks)
  • run_stac.py (1 hunks)
  • stac_mjx/__init__.py (1 hunks)
  • stac_mjx/compute_stac.py (9 hunks)
  • stac_mjx/io.py (2 hunks)
  • stac_mjx/io_dict_to_hdf5.py (0 hunks)
  • stac_mjx/main.py (4 hunks)
  • stac_mjx/op_utils.py (0 hunks)
  • stac_mjx/stac.py (14 hunks)
  • stac_mjx/stac_core.py (5 hunks)
  • stac_mjx/utils.py (1 hunks)
  • stac_mjx/viz.py (2 hunks)
  • tests/configs/model/test_mouse.yaml (0 hunks)
  • tests/configs/model/test_rodent.yaml (0 hunks)
  • tests/configs/model/test_rodent_label3d.yaml (0 hunks)
  • tests/configs/model/test_rodent_less_kp_names.yaml (0 hunks)
  • tests/configs/model/test_rodent_no_kp_names.yaml (0 hunks)
  • tests/configs/stac/test_stac.yaml (1 hunks)
  • tests/integration/test_model.py (1 hunks)
  • tests/test_io.py (6 hunks)
  • tests/unit/test_controller.py (1 hunks)
  • tests/unit/test_main.py (1 hunks)
  • tests/unit/test_utils.py (6 hunks)
💤 Files with no reviewable changes (12)
  • tests/configs/model/test_rodent.yaml
  • configs/model/fly_treadmill.yaml
  • tests/configs/model/test_rodent_label3d.yaml
  • tests/configs/model/test_rodent_no_kp_names.yaml
  • tests/configs/model/test_rodent_less_kp_names.yaml
  • configs/model/synth_data.yaml
  • configs/config.yaml
  • configs/model/rodent.yaml
  • tests/configs/model/test_mouse.yaml
  • configs/model/fly_tethered.yaml
  • stac_mjx/io_dict_to_hdf5.py
  • stac_mjx/op_utils.py
✅ Files skipped from review due to trivial changes (1)
  • tests/integration/test_model.py
🧰 Additional context used
🪛 yamllint (1.35.1)
configs/stac/stac.yaml

[error] 8-8: trailing spaces

(trailing-spaces)

configs/stac/stac_mouse.yaml

[error] 8-8: trailing spaces

(trailing-spaces)

tests/configs/stac/test_stac.yaml

[error] 8-8: trailing spaces

(trailing-spaces)

🪛 Ruff (0.8.2)
stac_mjx/__init__.py

3-3: stac_mjx.utils.enable_xla_flags imported but unused; consider removing, adding to __all__, or using a redundant alias

(F401)


4-4: stac_mjx.io.load_mocap imported but unused; consider removing, adding to __all__, or using a redundant alias

(F401)

stac_mjx/viz.py

8-8: omegaconf.DictConfig imported but unused

Remove unused import: omegaconf.DictConfig

(F401)


9-9: typing.Dict imported but unused

Remove unused import: typing.Dict

(F401)

stac_mjx/main.py

6-6: pickle imported but unused

Remove unused import: pickle

(F401)


8-8: logging imported but unused

Remove unused import: logging

(F401)


89-89: Found useless expression. Either assign it to a variable or remove it.

(B018)

stac_mjx/compute_stac.py

5-5: numpy imported but unused

Remove unused import: numpy

(F401)

stac_mjx/utils.py

10-10: numpy imported but unused

Remove unused import: numpy

(F401)


11-11: stac_mjx.io imported but unused

Remove unused import: stac_mjx.io

(F401)

⏰ Context from checks skipped due to timeout of 90000ms (2)
  • GitHub Check: Tests (ubuntu-latest, Python 3.11)
  • GitHub Check: Lint
🔇 Additional comments (33)
stac_mjx/__init__.py (1)

3-4: Imports updated correctly to reflect module restructuring

The import statements have been correctly updated to reflect the restructuring of utility functions and data loading methods.

🧰 Tools
🪛 Ruff (0.8.2)

3-3: stac_mjx.utils.enable_xla_flags imported but unused; consider removing, adding to __all__, or using a redundant alias

(F401)


4-4: stac_mjx.io.load_mocap imported but unused; consider removing, adding to __all__, or using a redundant alias

(F401)

tests/unit/test_main.py (1)

22-23: Test assertions updated to match new configuration parameters

The test assertions have been appropriately updated to align with the new configuration parameters in the Config dataclass, ensuring the test remains valid.

run_stac.py (1)

11-11: LGTM! Function update aligns with codebase changes.

The change from load_data to load_mocap is consistent with the broader refactoring across the codebase.

tests/unit/test_controller.py (1)

14-14: LGTM! Test updated to match new API.

The test has been correctly updated to use load_mocap while maintaining the same test logic and assertions.

stac_mjx/viz.py (1)

41-46: LGTM! Improved data loading implementation.

The new implementation using io.load_stac_data centralizes data loading logic and provides better organization.

tests/unit/test_utils.py (1)

Line range hint 34-94: LGTM! Comprehensive test coverage maintained.

All test cases have been consistently updated to use load_mocap while maintaining coverage for:

  • Different file types (.nwb, .mat, .h5)
  • Error cases (missing/insufficient keypoint names)
  • Various configurations
tests/test_io.py (2)

34-34: LGTM! Function call updated to use new load_mocap method.

The change from load_data to load_mocap is consistent with the refactoring across the codebase.


44-44: Consistent update of load_data to load_mocap across all test functions.

All test functions have been updated to use the new load_mocap method while maintaining the same test logic and assertions.

Also applies to: 56-56, 69-69, 80-80, 94-94

stac_mjx/main.py (3)

Line range hint 18-37: LGTM! Enhanced configuration loading with validation.

The changes improve configuration management by:

  1. Adding config_name parameter for flexibility
  2. Adding structured config validation
  3. Adding user feedback through print statements

72-78: LGTM! Added vectorized velocity computation.

Good use of JAX's vmap for efficient batch processing of kinematic data.


86-88: LGTM! Implemented h5 data saving.

The implementation uses a dedicated method save_data_to_h5 for consistent data storage.

Also applies to: 120-120

stac_mjx/stac_core.py (2)

13-13: LGTM! Consistent transition from op_utils to utils module.

All utility function calls have been updated to use the consolidated utils module while maintaining the same functionality.

Also applies to: 54-54, 60-61, 64-64, 104-104, 150-157


199-214: LGTM! Improved docstring clarity.

The docstring for m_opt has been updated with clearer descriptions and better-formatted argument documentation.

stac_mjx/compute_stac.py (2)

70-71: LGTM! Consistent transition from op_utils to utils module.

All utility function calls have been updated to use the consolidated utils module while maintaining the same functionality.

Also applies to: 96-97, 144-144, 165-167, 170-170


201-204: LGTM! Improved variable naming and added quaternion storage.

Changes improve code clarity through:

  1. More descriptive variable names (qposes, xposes, marker_sites)
  2. Added storage for quaternion data (xquats)

Also applies to: 260-263, 271-274

stac_mjx/utils.py (5)

14-21: LGTM! Well-documented GPU optimization.

The function enables XLA flags for optimized GPU performance with clear documentation.


23-30: LGTM! Clear model loading implementation.

The function provides a clean interface for loading Mujoco models into MJX with proper documentation.


32-58: LGTM! Efficient JIT-compiled kinematics functions.

The functions are well-documented and use JAX's JIT compilation for improved performance.


60-99: LGTM! Comprehensive site position handling.

The functions provide a complete set of operations for getting and setting site positions with proper documentation.


272-313: LGTM! Well-structured velocity computation.

The function handles both free and constrained joints with proper velocity limits and clear documentation.

demos/api_usage.ipynb (3)

54-54: LGTM! Updated function call.

The function call has been updated from load_data to load_mocap to reflect the API changes.


269-270: LGTM! Clear output messages.

The output messages clearly indicate the data saving path and skipped operations.


293-293: LGTM! Updated kernel display name.

The kernel display name has been updated to "stac-mjx-env" for consistency.

stac_mjx/stac.py (3)

15-15: LGTM! Consolidated imports.

The imports have been reorganized to use utils instead of op_utils.


111-111: LGTM! Added freejoint flag.

The _freejoint flag is correctly initialized based on the model's joint type.


433-442: LGTM! Improved data packaging.

The data packaging now uses descriptive variable names and returns a structured StacData object.

demos/viz_usage.ipynb (3)

78-78: LGTM! Updated data format.

The data path has been updated to use the .h5 format instead of .p.


83-89: LGTM! Improved function call formatting.

The viz_stac function call has been formatted for better readability with one parameter per line.


105-105: LGTM! Updated kernel display name.

The kernel display name has been updated to "stac-mjx-env" for consistency.

configs/stac/stac_synth_data.yaml (1)

1-2: Changes look good and well documented!

The updates align well with the PR objective of switching to h5 format, and the parameters are appropriately documented.

Also applies to: 7-9

configs/stac/demo.yaml (1)

1-2: Well documented changes with clear explanations!

The changes look good and include helpful comments explaining the parameters. Consider reusing these well-written comments across other config files for consistency.

Also applies to: 5-5, 8-9

configs/stac/stac_fly_treadmill.yaml (1)

Line range hint 1-2: Consider updating output paths to use .h5 extension.

The PR objective is to save output data as h5, but the output paths still use pickle format (.p extension). Consider updating these paths to use .h5 extension for consistency.

configs/stac/stac_fly_tethered.yaml (1)

12-13: Document n_frames_per_clip and verify value consistency.

  1. Consider adding a comment explaining the purpose of this parameter and its relationship with n_fit_frames.
  2. The value (300) differs from the treadmill config (581). Please verify if this difference is intentional.
✅ Verification successful

Values are intentionally different, but documentation needed

The different n_frames_per_clip values across configs are intentional, as this parameter controls batch processing size for parallel computation. Suggested documentation for configs/stac/stac_fly_tethered.yaml:

n_frames_per_clip: 300  # Number of frames to process in parallel. Must divide evenly into total mocap frames.

Note: This parameter is distinct from n_fit_frames (601), which is used only for the initial offset fitting phase.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Check n_frames_per_clip values across configs
echo "Checking n_frames_per_clip values across configs:"
rg -g '*.yaml' 'n_frames_per_clip:'

Length of output: 545


Script:

#!/bin/bash
# Search for documentation and usage of n_frames_per_clip
echo "Searching for documentation and usage of n_frames_per_clip:"
rg -g '!*.yaml' 'n_frames_per_clip' -B 2 -A 2

echo -e "\nSearching for n_fit_frames:"
rg 'n_fit_frames' -B 2 -A 2

Length of output: 4103


def viz_stac(
data_path: Union[Path, str],
cfg: DictConfig,
# cfg: DictConfig,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Remove commented parameter instead of commenting it out.

Clean up the function signature by removing the commented-out parameter entirely, as it's no longer used.

-    # cfg: DictConfig,

Comment on lines +10 to +11
import numpy as np
from stac_mjx import io
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Remove unused imports.

The following imports are not used in this file:

  • numpy
  • stac_mjx.io

Apply this diff to remove the unused imports:

-import numpy as np
-from stac_mjx import io
🧰 Tools
🪛 Ruff (0.8.2)

10-10: numpy imported but unused

Remove unused import: numpy

(F401)


11-11: stac_mjx.io imported but unused

Remove unused import: stac_mjx.io

(F401)

Comment on lines +101 to +134
def make_qs(q0, qs_to_opt, q):
"""Create new set of qs combining initial and new qs for part optimization based on qs_to_opt.

Args:
q0 (jax.Array): initial joint angles
qs_to_opt (jax.Array): joint angles that were optimized
q (jax.Array): new joint angles

Returns:
jp.Array: resulting set of joint angles
"""
return jp.copy((1 - qs_to_opt) * q0 + qs_to_opt * jp.copy(q))


def replace_qs(mjx_model: mjx.Model, mjx_data: mjx.Data, q):
"""Replace joint angles in mjx.Data with new ones and performs forward kinematics.

Args:
mjx_model (mjx.Model):
mjx_data (mjx.Data):
q (jax.Array): new joint angles

Returns:
mjx.Data: resulting mjx Data
"""
if q is None:
print("optimization failed, continuing")

else:
mjx_data = mjx_data.replace(qpos=q)
mjx_data = kinematics(mjx_model, mjx_data)

return mjx_data

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Add error handling for None input in replace_qs.

The function handles None input by printing a message and continuing, which could mask errors. Consider raising an exception or returning early with a meaningful error message.

Apply this diff to improve error handling:

 def replace_qs(mjx_model: mjx.Model, mjx_data: mjx.Data, q):
     """Replace joint angles in mjx.Data with new ones and performs forward kinematics.
 
     Args:
         mjx_model (mjx.Model):
         mjx_data (mjx.Data):
         q (jax.Array): new joint angles
 
     Returns:
         mjx.Data: resulting mjx Data
+
+    Raises:
+        ValueError: If q is None, indicating optimization failure
     """
     if q is None:
-        print("optimization failed, continuing")
+        raise ValueError("Optimization failed: received None for joint angles")
 
     else:
         mjx_data = mjx_data.replace(qpos=q)
         mjx_data = kinematics(mjx_model, mjx_data)
 
     return mjx_data
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def make_qs(q0, qs_to_opt, q):
"""Create new set of qs combining initial and new qs for part optimization based on qs_to_opt.
Args:
q0 (jax.Array): initial joint angles
qs_to_opt (jax.Array): joint angles that were optimized
q (jax.Array): new joint angles
Returns:
jp.Array: resulting set of joint angles
"""
return jp.copy((1 - qs_to_opt) * q0 + qs_to_opt * jp.copy(q))
def replace_qs(mjx_model: mjx.Model, mjx_data: mjx.Data, q):
"""Replace joint angles in mjx.Data with new ones and performs forward kinematics.
Args:
mjx_model (mjx.Model):
mjx_data (mjx.Data):
q (jax.Array): new joint angles
Returns:
mjx.Data: resulting mjx Data
"""
if q is None:
print("optimization failed, continuing")
else:
mjx_data = mjx_data.replace(qpos=q)
mjx_data = kinematics(mjx_model, mjx_data)
return mjx_data
def make_qs(q0, qs_to_opt, q):
"""Create new set of qs combining initial and new qs for part optimization based on qs_to_opt.
Args:
q0 (jax.Array): initial joint angles
qs_to_opt (jax.Array): joint angles that were optimized
q (jax.Array): new joint angles
Returns:
jp.Array: resulting set of joint angles
"""
return jp.copy((1 - qs_to_opt) * q0 + qs_to_opt * jp.copy(q))
def replace_qs(mjx_model: mjx.Model, mjx_data: mjx.Data, q):
"""Replace joint angles in mjx.Data with new ones and performs forward kinematics.
Args:
mjx_model (mjx.Model):
mjx_data (mjx.Data):
q (jax.Array): new joint angles
Returns:
mjx.Data: resulting mjx Data
Raises:
ValueError: If q is None, indicating optimization failure
"""
if q is None:
raise ValueError("Optimization failed: received None for joint angles")
else:
mjx_data = mjx_data.replace(qpos=q)
mjx_data = kinematics(mjx_model, mjx_data)
return mjx_data

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (3)
stac_mjx/utils.py (1)

10-11: Remove unused imports.

The following imports are not used in this file:

  • numpy
  • stac_mjx.io

Apply this diff to remove the unused imports:

-import numpy as np
-from stac_mjx import io
🧰 Tools
🪛 Ruff (0.8.2)

10-10: numpy imported but unused

Remove unused import: numpy

(F401)


11-11: stac_mjx.io imported but unused

Remove unused import: stac_mjx.io

(F401)

stac_mjx/io.py (2)

265-326: Add error handling for HDF5 operations.

While the HDF5 saving functions are well-structured, they could benefit from error handling for common issues like:

  • File permission errors
  • Disk space issues
  • Invalid data types

Would you like me to provide an example implementation with proper error handling?


328-368: Add input validation for HDF5 loading.

Consider adding validation for:

  • File existence and format
  • Required dataset presence
  • Data shape consistency

Would you like me to provide an example implementation with proper validation?

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2075ed4 and df031e1.

📒 Files selected for processing (4)
  • stac_mjx/io.py (2 hunks)
  • stac_mjx/stac.py (14 hunks)
  • stac_mjx/utils.py (1 hunks)
  • stac_mjx/viz.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • stac_mjx/stac.py
🧰 Additional context used
🪛 Ruff (0.8.2)
stac_mjx/viz.py

8-8: omegaconf.DictConfig imported but unused

Remove unused import: omegaconf.DictConfig

(F401)


9-9: typing.Dict imported but unused

Remove unused import: typing.Dict

(F401)

stac_mjx/utils.py

10-10: numpy imported but unused

Remove unused import: numpy

(F401)


11-11: stac_mjx.io imported but unused

Remove unused import: stac_mjx.io

(F401)

⏰ Context from checks skipped due to timeout of 90000ms (2)
  • GitHub Check: Tests (ubuntu-latest, Python 3.11)
  • GitHub Check: Lint
🔇 Additional comments (4)
stac_mjx/viz.py (2)

14-14: Remove commented parameter instead of commenting it out.

Clean up the function signature by removing the commented-out parameter entirely, as it's no longer used.


40-44: LGTM! Good improvements to data handling.

The changes improve code organization by:

  • Centralizing data loading in io.load_stac_data.
  • Using structured data class for better data organization.
  • Providing configuration back to caller for potential use.

Also applies to: 54-63

stac_mjx/utils.py (1)

126-128: Add error handling for None input in replace_qs.

The function handles None input by printing a message and continuing, which could mask errors. Consider raising an exception or returning early with a meaningful error message.

stac_mjx/io.py (1)

23-96: Well-structured configuration and data management using dataclasses!

The dataclass hierarchy provides:

  • Clear organization of configuration parameters
  • Type safety through annotations
  • Easy serialization/deserialization


@dataclass
class ModelConfig:
"""Configuration for body model."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any chance we could document what the attributes mean in this and subsequent dataclass docstrings?

stac_mjx/io.py Outdated
Comment on lines 347 to 353
kp_data = np.array(f["kp_data"])
marker_sites = np.array(f["marker_sites"])
offsets = np.array(f["offsets"])
qpos = np.array(f["qpos"])
qvel = np.array(f["qvel"])
xpos = np.array(f["xpos"])
xquat = np.array(f["xquat"])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Canonical way to load the datasets into memory is like:

with h5py.File(file_path, "r") as f:
    kp_data = f["kp_data"][()]

Ref: https://docs.h5py.org/en/stable/high/dataset.html#reading-writing-data

Numpy array casting has the same effect by forcing the read op, but doing it the way above maintains type and works with scalars and other primitives without coercing them into arrays.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

stac_mjx/main.py Outdated

# Stop here if not doing ik only phase
if cfg.stac.skip_ik_only == 1:
logging.info("skipping ik_only()")
print("skipping ik_only()")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's give the users a bit more of an informative error message? Maybe:

Suggested change
print("skipping ik_only()")
print("Skipping IK-only phase. Only offsets will be returned. To change this behavior, set cfg.stac.skip_ik_only to 0.")


Args:
qpos_trajectory (jp.ndarray): trajectory of qpos values T x ?
Note assumes has freejoint as the first 7 dimensions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we just have this as a kwarg with 7 as default? Will this change across body models?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The first 7 dimensions is always the freejoint if it exists; the freejoint arg accounts for this so I'll use that when calling this function

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (7)
stac_mjx/main.py (1)

6-6: Remove unused imports.

The following imports are not used in the code:

  • pickle
  • logging
-import pickle
import time
-import logging
from omegaconf import DictConfig, OmegaConf

Also applies to: 8-8

🧰 Tools
🪛 Ruff (0.8.2)

6-6: pickle imported but unused

Remove unused import: pickle

(F401)

stac_mjx/viz.py (1)

8-9: Remove unused imports.

The following imports are not used in the code:

  • DictConfig from omegaconf
  • Dict from typing
-from omegaconf import DictConfig
-from typing import Union, Dict
+from typing import Union
🧰 Tools
🪛 Ruff (0.8.2)

8-8: omegaconf.DictConfig imported but unused

Remove unused import: omegaconf.DictConfig

(F401)


9-9: typing.Dict imported but unused

Remove unused import: typing.Dict

(F401)

stac_mjx/io.py (3)

269-282: Consider error handling for HDF5 operations.

The save_dict_to_hdf5 function should handle potential HDF5 errors that may occur during group creation or attribute setting.

 def save_dict_to_hdf5(group, dictionary):
     """Save a dictionary to an HDF5 group.
 
     Args:
         group (h5py.Group): HDF5 group to save the dictionary to.
         dictionary (dict): Dictionary to save.
+
+    Raises:
+        h5py.H5Error: If there's an error during HDF5 operations.
     """
-    for key, value in dictionary.items():
-        if isinstance(value, dict):
-            subgroup = group.create_group(key)
-            save_dict_to_hdf5(subgroup, value)
-        else:
-            group.attrs[key] = value
+    try:
+        for key, value in dictionary.items():
+            if isinstance(value, dict):
+                subgroup = group.create_group(key)
+                save_dict_to_hdf5(subgroup, value)
+            else:
+                group.attrs[key] = value
+    except (TypeError, ValueError) as e:
+        raise ValueError(f"Failed to save dictionary to HDF5: {e}")

284-330: Consider adding data validation.

The save_data_to_h5 function should validate input arrays for shape consistency and non-null values before saving.

 def save_data_to_h5(...):
     """Save configuration and STAC data to an HDF5 file."""
+    # Validate input arrays
+    if not all(isinstance(arr, np.ndarray) for arr in [kp_data, marker_sites, offsets, qpos, xpos, xquat, qvel]):
+        raise ValueError("All data arrays must be numpy arrays")
+
+    # Validate array shapes
+    if len(kp_names) != kp_data.shape[1]:
+        raise ValueError(f"Mismatch between kp_names length ({len(kp_names)}) and kp_data shape ({kp_data.shape})")
+
     with h5py.File(file_path, "w") as f:
         # Save config as a YAML string
         config_yaml = OmegaConf.to_yaml(OmegaConf.structured(config))

332-372: Consider adding file existence check.

The load_stac_data function should verify file existence before attempting to load it.

 def load_stac_data(file_path) -> tuple[Config, StacData]:
     """Load configuration and STAC data from an HDF5 file."""
+    if not os.path.exists(file_path):
+        raise FileNotFoundError(f"HDF5 file not found: {file_path}")
+
     with h5py.File(file_path, "r") as f:
stac_mjx/utils.py (2)

281-281: Improve documentation clarity for freejoint dimensions.

Consider adding a constant like FREEJOINT_DIMS = 7 or enhancing the comment to explain why the first 7 dimensions correspond to the freejoint (3 for translation + 4 for quaternion rotation).


298-305: Consider vectorizing the quaternion operations.

The loop for computing qvel_gyro could be vectorized using vmap for better performance:

-        qvel_gyro = []
-        for t in range(qpos_trajectory.shape[0] - 1):
-            normed_diff = quat_diff(
-                qpos_trajectory[t, 3:7], qpos_trajectory[t + 1, 3:7]
-            )
-            normed_diff /= jp.linalg.norm(normed_diff)
-            angle = quat_to_axisangle(normed_diff)
-            qvel_gyro.append(angle / dt)
-        qvel_gyro = jp.stack(qvel_gyro)
+        def compute_gyro(t):
+            normed_diff = quat_diff(
+                qpos_trajectory[t, 3:7], qpos_trajectory[t + 1, 3:7]
+            )
+            normed_diff /= jp.linalg.norm(normed_diff)
+            return quat_to_axisangle(normed_diff) / dt
+        qvel_gyro = jax.vmap(compute_gyro)(jp.arange(qpos_trajectory.shape[0] - 1))
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between df031e1 and bbfd6e6.

📒 Files selected for processing (4)
  • stac_mjx/io.py (2 hunks)
  • stac_mjx/main.py (4 hunks)
  • stac_mjx/utils.py (1 hunks)
  • stac_mjx/viz.py (2 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
stac_mjx/main.py

6-6: pickle imported but unused

Remove unused import: pickle

(F401)


8-8: logging imported but unused

Remove unused import: logging

(F401)


89-89: Found useless expression. Either assign it to a variable or remove it.

(B018)

stac_mjx/utils.py

10-10: numpy imported but unused

Remove unused import: numpy

(F401)


11-11: stac_mjx.io imported but unused

Remove unused import: stac_mjx.io

(F401)

stac_mjx/viz.py

8-8: omegaconf.DictConfig imported but unused

Remove unused import: omegaconf.DictConfig

(F401)


9-9: typing.Dict imported but unused

Remove unused import: typing.Dict

(F401)

⏰ Context from checks skipped due to timeout of 90000ms (2)
  • GitHub Check: Tests (ubuntu-latest, Python 3.11)
  • GitHub Check: Lint
🔇 Additional comments (14)
stac_mjx/main.py (5)

18-20: LGTM! Well-documented function signature update.

The addition of the config_name parameter with a default value maintains backward compatibility while adding flexibility.


32-36: LGTM! Enhanced configuration validation.

Good improvements:

  • Using structured configuration validation
  • Adding user feedback about config loading

72-78: LGTM! Efficient velocity computation using JAX's vmap.

Good optimization using vectorized operations for computing velocities.


97-99: Enhance error message with more information.

Consider providing more informative error messages to guide users.

-            "Skipping IK-only phase. To change this behavior, set cfg.stac.skip_ik_only to 0."
+            "Skipping IK-only phase. Only offsets will be returned. To change this behavior, set cfg.stac.skip_ik_only to 0."

86-88: LGTM! Consistent use of H5 format for data saving.

Good use of the new save_data_to_h5 function for consistent data storage.

Also applies to: 128-128

stac_mjx/viz.py (2)

40-44: LGTM! Clean data loading implementation.

The code now uses the centralized load_stac_data function to load both configuration and data, which is a cleaner approach than the previous implementation.


Line range hint 54-64: LGTM! Return statement updated to include configuration.

The return statement now includes the configuration object, which provides additional context alongside the rendering result.

stac_mjx/io.py (1)

23-104: LGTM! Well-documented dataclasses.

The dataclasses are well-structured and each attribute is properly documented with clear descriptions.

stac_mjx/utils.py (6)

10-11: Remove unused imports.

The imports of numpy and stac_mjx.io are not used in this file.

🧰 Tools
🪛 Ruff (0.8.2)

10-10: numpy imported but unused

Remove unused import: numpy

(F401)


11-11: stac_mjx.io imported but unused

Remove unused import: stac_mjx.io

(F401)


14-21: LGTM! Good GPU optimization.

The function correctly checks for GPU backend before enabling XLA flags for optimized performance.


115-134: Improve error handling in replace_qs.

The function should raise an exception for None input rather than printing a message and continuing.


23-58: LGTM! Well-structured kinematics functions.

The model loading and kinematics functions are well-implemented with proper JIT compilation for performance.


60-99: LGTM! Clear site position manipulation.

The site position manipulation functions are well-documented and correctly implemented.


136-269: LGTM! Robust quaternion operations.

The quaternion operations are thoroughly implemented with proper error handling and precision checks.

io.save_data_to_h5(
config=cfg, file_path=fit_offsets_path, **fit_offsets_data.as_dict()
)
(fit_offsets_data, fit_offsets_path)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Remove unused tuple expression.

This line creates a tuple that is neither assigned to a variable nor returned.

-        (fit_offsets_data, fit_offsets_path)
🧰 Tools
🪛 Ruff (0.8.2)

89-89: Found useless expression. Either assign it to a variable or remove it.

(B018)

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (3)
stac_mjx/io.py (3)

24-47: Enhance dataclass attribute documentation.

While the dataclasses are well-structured, the attribute documentation could be more descriptive to improve maintainability. Consider adding more detailed descriptions for each field, including:

  • Valid value ranges or constraints
  • Units of measurement where applicable
  • Dependencies between fields
  • Examples of valid values

Example for ModelConfig:

@dataclass
class ModelConfig:
    """Configuration for body model.
    
    Attributes:
        MJCF_PATH: Path to the MuJoCo XML model file (e.g., "models/mouse.xml")
        FTOL: Tolerance threshold for optimization convergence (e.g., 1e-6)
        ROOT_FTOL: Tolerance threshold for root optimization (e.g., 1e-6)
        LIMB_FTOL: Tolerance threshold for limb optimization (e.g., 1e-6)
        N_ITERS: Number of iterations for STAC algorithm (e.g., 100)
        ...
    """

Also applies to: 49-56, 58-72, 74-80, 82-104


341-358: Use canonical HDF5 data loading pattern.

Consider using the canonical way to load data from HDF5 files for better type preservation and handling of scalars:

-        kp_data = f["kp_data"][()]
-        marker_sites = f["marker_sites"][()]
-        offsets = f["offsets"][()]
-        qpos = f["qpos"][()]
-        qvel = f["qvel"][()]
-        xpos = f["xpos"][()]
-        xquat = f["xquat"][()]
+        with h5py.File(file_path, "r") as f:
+            kp_data = f["kp_data"][()]
+            marker_sites = f["marker_sites"][()]
+            offsets = f["offsets"][()]
+            qpos = f["qpos"][()]
+            qvel = f["qvel"][()]
+            xpos = f["xpos"][()]
+            xquat = f["xquat"][()]

Reference: https://docs.h5py.org/en/stable/high/dataset.html#reading-writing-data


323-329: Consider optimizing HDF5 compression settings.

The current implementation uses gzip compression without specifying the compression level. Consider tuning the compression settings based on your data characteristics:

-        f.create_dataset("kp_data", data=kp_data, compression="gzip")
-        f.create_dataset("marker_sites", data=marker_sites, compression="gzip")
-        f.create_dataset("offsets", data=offsets, compression="gzip")
-        f.create_dataset("qpos", data=qpos, compression="gzip")
-        f.create_dataset("qvel", data=qvel, compression="gzip")
-        f.create_dataset("xpos", data=xpos, compression="gzip")
-        f.create_dataset("xquat", data=xquat, compression="gzip")
+        f.create_dataset("kp_data", data=kp_data, compression="gzip", compression_opts=4)
+        f.create_dataset("marker_sites", data=marker_sites, compression="gzip", compression_opts=4)
+        f.create_dataset("offsets", data=offsets, compression="gzip", compression_opts=4)
+        f.create_dataset("qpos", data=qpos, compression="gzip", compression_opts=4)
+        f.create_dataset("qvel", data=qvel, compression="gzip", compression_opts=4)
+        f.create_dataset("xpos", data=xpos, compression="gzip", compression_opts=4)
+        f.create_dataset("xquat", data=xquat, compression="gzip", compression_opts=4)

The compression level (1-9) trades off compression ratio vs. speed:

  • Lower values (1-4): Faster compression, larger file size
  • Higher values (5-9): Better compression, slower performance
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between bbfd6e6 and 52e3207.

📒 Files selected for processing (1)
  • stac_mjx/io.py (2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (2)
  • GitHub Check: Tests (ubuntu-latest, Python 3.11)
  • GitHub Check: Lint
🔇 Additional comments (1)
stac_mjx/io.py (1)

Line range hint 106-268: LGTM! Well-structured data loading function.

The renamed load_mocap function is well-documented, has proper error handling, and follows good practices. The docstring clearly describes the function's behavior, arguments, return values, and potential errors.

@charles-zhng charles-zhng merged commit 5d99b42 into main Feb 6, 2025
5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants