Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Synthetic data smoke test. #75

Merged
merged 29 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
bbde0a4
Model runs + draws in notebook, no data output
jf514 Oct 8, 2024
4661907
Configs etc - model not yet working.
jf514 Oct 25, 2024
d5cdd97
Disable energy
jf514 Oct 25, 2024
9baa1c7
Offset shape bugfix (#73)
charles-zhng Oct 20, 2024
a6b983f
Smoke test (#74)
jf514 Oct 25, 2024
7ed9f25
IT'S WORKING
jf514 Oct 25, 2024
d4145c5
Offset shape bugfix (#73)
charles-zhng Oct 20, 2024
8248778
Configs etc - model not yet working.
jf514 Oct 25, 2024
be2f1aa
Offset shape bugfix (#73)
charles-zhng Oct 20, 2024
935c922
Configs etc - model not yet working.
jf514 Oct 25, 2024
2e6e01c
Merge remote-tracking branch 'origin/main' into synthetic-data
jf514 Oct 26, 2024
45774d9
Fix weird merge.
jf514 Oct 26, 2024
e9c1a4f
Clean up synth_model config file.
jf514 Oct 28, 2024
f14910b
Remove TIME_BINS (which was a merge accident.)
jf514 Oct 28, 2024
f6df7e3
Fix smoke test.
jf514 Oct 28, 2024
7b16b5a
Fix smoke test.
jf514 Oct 28, 2024
b75b354
Clean up.
jf514 Oct 28, 2024
6f7da79
Fixed root optimization, but still some debug code.
jf514 Oct 28, 2024
5433344
Add root_kp_index
jf514 Oct 29, 2024
2405815
Forgot model yaml.
jf514 Oct 29, 2024
9a2a599
Reset rodent configs, enable synth config.
jf514 Oct 29, 2024
86d333e
Add synth_data smoke test.
jf514 Oct 29, 2024
84793e9
Missed data file.
jf514 Oct 29, 2024
82c1887
Clean up.
jf514 Oct 30, 2024
d9fe782
Add root opt keypoint to model configs + clean up.
jf514 Oct 30, 2024
418a1e9
Clean up.
jf514 Oct 30, 2024
19a30e0
CR feedback.
jf514 Nov 1, 2024
dc26558
Add synth data generation program.
jf514 Nov 1, 2024
dc881c3
Add comments.
jf514 Nov 1, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ jobs:
verbose: false
token: ${{ secrets.CODECOV_TOKEN }}

# Test probably delete this
# Smoke test. Shows end to end run with out crashing.
- name: Smoke Test
shell: bash -l {0}
run: python run_stac.py
run: python run_stac.py --config-name config_synth_data.yaml
594 changes: 594 additions & 0 deletions Mat-to-Nwb-Synth-Data.ipynb

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions configs/config_synth_data.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
defaults:
- stac: stac_synth_data
- model: synth_data
- _self_

2 changes: 0 additions & 2 deletions configs/model/fly_tethered.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -176,5 +176,3 @@ N_SAMPLE_FRAMES: 100
# If you have reason to believe that the initial offsets are correct for particular keypoints,
# you can regularize those sites using _SITES_TO_REGULARIZE.
M_REG_COEF: 1

TIME_BINS: 0.02
4 changes: 2 additions & 2 deletions configs/model/fly_treadmill.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ KEYPOINT_INITIAL_OFFSETS:
l3: 0. 0. 0.
r3: 0. 0. 0.

ROOT_OPTIMIZATION_KEYPOINT: head

TRUNK_OPTIMIZATION_KEYPOINTS:
- 'head'
- 'thorax'
Expand Down Expand Up @@ -101,5 +103,3 @@ N_SAMPLE_FRAMES: 100
# If you have reason to believe that the initial offsets are correct for particular keypoints,
# you can regularize those sites using _SITES_TO_REGULARIZE.
M_REG_COEF: 1

TIME_BINS: 0.02
2 changes: 2 additions & 0 deletions configs/model/mouse.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ KEYPOINT_INITIAL_OFFSETS:
Lisfranc_L: 0.0 0.0 0.0
MTP_R: 0.0 0.0 0.0

ROOT_OPTIMIZATION_KEYPOINT: Trunk

TRUNK_OPTIMIZATION_KEYPOINTS:
- "Trunk"
- "HipL"
Expand Down
2 changes: 2 additions & 0 deletions configs/model/rodent.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ KEYPOINT_INITIAL_OFFSETS:
WristL: 0. 0. 0.0
WristR: 0. 0. 0.0

ROOT_OPTIMIZATION_KEYPOINT: SpineL

TRUNK_OPTIMIZATION_KEYPOINTS:
- "SpineF"
- "SpineL"
Expand Down
58 changes: 58 additions & 0 deletions configs/model/synth_data.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@

MJCF_PATH: 'models/synth_model.xml'
jf514 marked this conversation as resolved.
Show resolved Hide resolved

# Frames per clip for transform.
N_FRAMES_PER_CLIP: 1

# Tolerance for the optimizations of the full model, limb, and root.
# TODO: Re-implement optimizer loops to use these tolerances
FTOL: 5.0e-03
ROOT_FTOL: 1.0e-05
LIMB_FTOL: 1.0e-06

# Number of alternating pose and offset optimization rounds.
N_ITERS: 1

KP_NAMES:
- part_0

ROOT_OPTIMIZATION_KEYPOINT: part_0

# The model sites used to register the keypoints.
KEYPOINT_MODEL_PAIRS:
part_0: base

# The initial offsets for each keypoint in meters.
KEYPOINT_INITIAL_OFFSETS:
part_0: 0 0 0.01

TRUNK_OPTIMIZATION_KEYPOINTS:
- part_0

INDIVIDUAL_PART_OPTIMIZATION:
model_base: [base]

# Color to use for each keypoint when visualizing the results
KEYPOINT_COLOR_PAIRS:
part_0: 0 .5 1 1

# What is the size of the animal you'd like to register, relative to the model?
SCALE_FACTOR: 1

# Multiplier to put the mocap data into the same scale as the data. Eg, if the
# mocap data is known to be in millimeters and the model is in meters, this is
# .001
MOCAP_SCALE_FACTOR: 1

# If you have reason to believe that the initial offsets are correct for particular keypoints,
# you can regularize those sites using this with M_REG_COEF.
SITES_TO_REGULARIZE:
- part_0

RENDER_FPS: 200

N_SAMPLE_FRAMES: 1

# If you have reason to believe that the initial offsets are correct for particular keypoints,
# you can regularize those sites using _SITES_TO_REGULARIZE.
M_REG_COEF: 1
jf514 marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 1 addition & 1 deletion configs/stac/stac_fly_treadmill.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ ik_only_path: "transform_treadmill.p"
# File is too large to commit
# DL from: https://datadryad.org/stash/dataset/doi:10.5061/dryad.mpg4f4r73
# Actual file: https://datadryad.org/stash/downloads/file_stream/3361804
data_path: "/tests/data/wt_berlin_linear_treadmill_dataset.csv"
data_path: "../tests/data/wt_berlin_linear_treadmill_dataset.csv"
jf514 marked this conversation as resolved.
Show resolved Hide resolved

n_fit_frames: 1800
skip_fit: False
Expand Down
12 changes: 12 additions & 0 deletions configs/stac/stac_synth_data.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
fit_offsets_path: "synth_fit.p"
ik_only_path: "synth_ik_only.p"
data_path: "tests/data/test_synth_1_frames.nwb"

n_fit_frames: 1
skip_fit_offsets: False
skip_ik_only: False

mujoco:
solver: newton
iterations: 1
ls_iterations: 4
254 changes: 254 additions & 0 deletions demos/create_synth_data.ipynb
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How was the data saved from this notebook? And were there any offsets applied to the synthetic keypoint? I wanted to test whether it finds the ground truth offset when given different initial offsets in the model.yaml file, but it always returns the initial offset without any changes.

This notebook doesn't run as is; the rendering is cell 3 is different for me from the one shared, and the last cell throw an error: KeyError: "Invalid name '0'. Valid names: ['base', 'world']"

Copy link
Contributor Author

@jf514 jf514 Nov 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it's just not ready for that. I changed to the title to reflect that it's just a smoke test... ie it runs with out crashing.

This PR isn't to provide that level of functionality, sadly. I've created #81 to track the next steps. Please feel free to add any specific requests to that.

As for the data generation, I created a single frame of fake data (not even collected from an actual Mujoco run) just to give an output.

Also, if you think there are any comments that need to be added to the code to make this clear, also lmk.

Large diffs are not rendered by default.

18 changes: 18 additions & 0 deletions models/synth_model.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
<mujoco>
<option timestep=".001">
</option>

<default>
<joint type="hinge" axis="0 -1 0"/>
<geom type="capsule" size=".02"/>
</default>

jf514 marked this conversation as resolved.
Show resolved Hide resolved
<worldbody>
<light pos="0 -.4 1"/>
<camera name="fixed" pos="0 -1 0" xyaxes="1 0 0 0 0 1"/>
<body name="base" pos="0 0 .2">
<joint type="free" name="root"/>
<geom fromto="0 0 0 0 0 -.25" rgba="1 1 0 1"/>
</body>
</worldbody>
</mujoco>
3 changes: 2 additions & 1 deletion stac_mjx/compute_stac.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def root_optimization(
mjx_model,
mjx_data,
kp_data: jp.ndarray,
root_kp_idx: int,
lb: jp.ndarray,
ub: jp.ndarray,
site_idxs: jp.ndarray,
Expand Down Expand Up @@ -50,7 +51,7 @@ def root_optimization(
# necessarily exactly so. The value of 3*18 is chosen for the
# rodent.xml, corresponding to the index of 'SpineL' keypoint.
# For the mouse model this should be 3*5, corresponding 'Trunk'
root_kp_idx = 3 * 18
# root_kp_idx = 3 * 18
# FLY_MODEL:
# root_kp_idx = 0
q0.at[:3].set(kp_data[frame, :][root_kp_idx : root_kp_idx + 3])
Expand Down
27 changes: 22 additions & 5 deletions stac_mjx/stac.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,22 @@
self._mj_model.body(i).name for i in range(self._mj_model.nbody)
]

joint_names = [self._mj_model.joint(i).name for i in range(self._mj_model.njnt)]
if "ROOT_OPTIMIZATION_KEYPOINT" in self.cfg.model:
self._root_kp_idx = self._kp_names.index(
self.cfg.model.ROOT_OPTIMIZATION_KEYPOINT
)
else:
self._root_kp_idx = -1

Check warning on line 88 in stac_mjx/stac.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/stac.py#L88

Added line #L88 was not covered by tests
jf514 marked this conversation as resolved.
Show resolved Hide resolved

# Set up bounds and part_names based on joint ranges, taking into account the dimensionality of parameters
joint_names = [self._mj_model.joint(i).name for i in range(self._mj_model.njnt)]
self._lb, self._ub, self._part_names = _align_joint_dims(
self._mj_model.jnt_type, self._mj_model.jnt_range, joint_names
)

self._indiv_parts = self.part_opt_setup()

# Generate boolean flags for keypoints included in trunk optimization.
self._trunk_kps = jp.array(
[n in self.cfg.model.TRUNK_OPTIMIZATION_KEYPOINTS for n in kp_names],
)
Expand All @@ -113,7 +120,7 @@
[any(part in name for part in parts) for name in self._part_names]
)

if self.cfg.model.INDIVIDUAL_PART_OPTIMIZATION is None:
if "INDIVIDUAL_PART_OPTIMIZATION" not in self.cfg.model:
indiv_parts = []
else:
indiv_parts = jp.array(
Expand Down Expand Up @@ -224,11 +231,16 @@

# Begin optimization steps
# Skip root optimization if model is fixed (no free joint at root)
if self._mj_model.jnt_type[0] == mujoco.mjtJoint.mjJNT_FREE:
if self._root_kp_idx == -1:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sense 👍

print(

Check warning on line 235 in stac_mjx/stac.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/stac.py#L234-L235

Added lines #L234 - L235 were not covered by tests
"ROOT_OPTIMIZATION_KEYPOINT not specified, skipping Root Optimization."
)
elif self._mj_model.jnt_type[0] == mujoco.mjtJoint.mjJNT_FREE:

Check warning on line 238 in stac_mjx/stac.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/stac.py#L238

Added line #L238 was not covered by tests
jf514 marked this conversation as resolved.
Show resolved Hide resolved
mjx_data = compute_stac.root_optimization(
mjx_model,
mjx_data,
kp_data,
self._root_kp_idx,
self._lb,
self._ub,
self._body_site_idxs,
Expand Down Expand Up @@ -339,15 +351,20 @@
)

# q_phase - root
if self._mj_model.jnt_type[0] == mujoco.mjtJoint.mjJNT_FREE:
if self._root_kp_idx == -1:
print(

Check warning on line 355 in stac_mjx/stac.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/stac.py#L354-L355

Added lines #L354 - L355 were not covered by tests
"Missing or invalid ROOT_OPTIMIZATION_KEYPOINT, skipping root_optimization()"
)
elif self._mj_model.jnt_type[0] == mujoco.mjtJoint.mjJNT_FREE:

Check warning on line 358 in stac_mjx/stac.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/stac.py#L358

Added line #L358 was not covered by tests
jf514 marked this conversation as resolved.
Show resolved Hide resolved
vmap_root_opt = jax.vmap(
compute_stac.root_optimization,
in_axes=(0, 0, 0, None, None, None, None),
in_axes=(0, 0, 0, None, None, None, None, None),
jf514 marked this conversation as resolved.
Show resolved Hide resolved
)
mjx_data = vmap_root_opt(
mjx_model,
mjx_data,
batched_kp_data,
self._root_kp_idx,
self._lb,
self._ub,
self._body_site_idxs,
Expand Down
2 changes: 2 additions & 0 deletions tests/configs/model/test_mouse.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ KEYPOINT_INITIAL_OFFSETS:
Lisfranc_L: 0.0 0.0 0.0
MTP_R: 0.0 0.0 0.0

ROOT_OPTIMIZATION_KEYPOINT: Trunk

TRUNK_OPTIMIZATION_KEYPOINTS:
- "Trunk"
- "HipL"
Expand Down
2 changes: 2 additions & 0 deletions tests/configs/model/test_rodent.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ KEYPOINT_INITIAL_OFFSETS:
WristL: 0. 0. 0.0
WristR: 0. 0. 0.0

ROOT_OPTIMIZATION_KEYPOINT: SpineL

jf514 marked this conversation as resolved.
Show resolved Hide resolved
TRUNK_OPTIMIZATION_KEYPOINTS:
- "Spine"
- "Hip"
Expand Down
3 changes: 2 additions & 1 deletion tests/configs/model/test_rodent_label3d.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ N_FRAMES_PER_CLIP: 360
# presumed to be derived from label3d:
KP_NAMES_LABEL3D_PATH: "tests/data/rat23.mat"


# The model sites used to register the keypoints.
KEYPOINT_MODEL_PAIRS:
AnkleL: lower_leg_L
Expand Down Expand Up @@ -61,6 +60,8 @@ KEYPOINT_INITIAL_OFFSETS:
WristL: 0. 0. 0.0
WristR: 0. 0. 0.0

ROOT_OPTIMIZATION_KEYPOINT: SpineL

TRUNK_OPTIMIZATION_KEYPOINTS:
- "Spine"
- "Hip"
Expand Down
2 changes: 2 additions & 0 deletions tests/configs/model/test_rodent_less_kp_names.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ KEYPOINT_INITIAL_OFFSETS:
WristL: 0. 0. 0.0
WristR: 0. 0. 0.0

ROOT_OPTIMIZATION_KEYPOINT: SpineL

jf514 marked this conversation as resolved.
Show resolved Hide resolved
TRUNK_OPTIMIZATION_KEYPOINTS:
- "Spine"
- "Hip"
Expand Down
2 changes: 2 additions & 0 deletions tests/configs/model/test_rodent_no_kp_names.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ KEYPOINT_INITIAL_OFFSETS:
WristL: 0. 0. 0.0
WristR: 0. 0. 0.0

ROOT_OPTIMIZATION_KEYPOINT: SpineL

TRUNK_OPTIMIZATION_KEYPOINTS:
- "Spine"
- "Hip"
Expand Down
Binary file added tests/data/test_synth_1_frames.nwb
Binary file not shown.
Loading