diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8bdd026..d0a5bd9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -101,7 +101,7 @@ jobs: verbose: false token: ${{ secrets.CODECOV_TOKEN }} - # Test probably delete this + # Smoke test. Shows end to end run with out crashing. - name: Smoke Test shell: bash -l {0} - run: python run_stac.py \ No newline at end of file + run: python run_stac.py stac=stac_synth_data model=synth_data \ No newline at end of file diff --git a/Mat-to-Nwb-Synth-Data.ipynb b/Mat-to-Nwb-Synth-Data.ipynb new file mode 100644 index 0000000..f268184 --- /dev/null +++ b/Mat-to-Nwb-Synth-Data.ipynb @@ -0,0 +1,594 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "mJSWVDRbCkMD", + "outputId": "f60eafc3-84b5-4c37-e711-bf2378115d73" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: ndx-pose in /opt/conda/envs/stac-mjx-env/lib/python3.11/site-packages (0.1.1)\n", + "Requirement already satisfied: pynwb<3,>=1.5.0 in /opt/conda/envs/stac-mjx-env/lib/python3.11/site-packages (from ndx-pose) (2.8.2)\n", + "Requirement already satisfied: hdmf<4,>=2.5.6 in /opt/conda/envs/stac-mjx-env/lib/python3.11/site-packages (from ndx-pose) (3.14.5)\n", + "Requirement already satisfied: h5py>=2.10 in /opt/conda/envs/stac-mjx-env/lib/python3.11/site-packages (from hdmf<4,>=2.5.6->ndx-pose) (3.11.0)\n", + "Requirement already satisfied: jsonschema>=2.6.0 in /opt/conda/envs/stac-mjx-env/lib/python3.11/site-packages (from hdmf<4,>=2.5.6->ndx-pose) (4.23.0)\n", + "Requirement already satisfied: numpy>=1.18 in /opt/conda/envs/stac-mjx-env/lib/python3.11/site-packages (from hdmf<4,>=2.5.6->ndx-pose) (1.26.4)\n", + "Requirement already satisfied: pandas>=1.0.5 in /opt/conda/envs/stac-mjx-env/lib/python3.11/site-packages (from hdmf<4,>=2.5.6->ndx-pose) (2.2.3)\n", + "Requirement already satisfied: ruamel-yaml>=0.16 in /opt/conda/envs/stac-mjx-env/lib/python3.11/site-packages (from hdmf<4,>=2.5.6->ndx-pose) (0.18.6)\n", + "Requirement already satisfied: scipy>=1.4 in /opt/conda/envs/stac-mjx-env/lib/python3.11/site-packages (from hdmf<4,>=2.5.6->ndx-pose) (1.14.1)\n", + "Requirement already satisfied: python-dateutil>=2.7.3 in /opt/conda/envs/stac-mjx-env/lib/python3.11/site-packages (from pynwb<3,>=1.5.0->ndx-pose) (2.9.0)\n", + "Requirement already satisfied: attrs>=22.2.0 in /opt/conda/envs/stac-mjx-env/lib/python3.11/site-packages (from jsonschema>=2.6.0->hdmf<4,>=2.5.6->ndx-pose) (24.2.0)\n", + "Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /opt/conda/envs/stac-mjx-env/lib/python3.11/site-packages (from jsonschema>=2.6.0->hdmf<4,>=2.5.6->ndx-pose) (2024.10.1)\n", + "Requirement already satisfied: referencing>=0.28.4 in /opt/conda/envs/stac-mjx-env/lib/python3.11/site-packages (from jsonschema>=2.6.0->hdmf<4,>=2.5.6->ndx-pose) (0.35.1)\n", + "Requirement already satisfied: rpds-py>=0.7.1 in /opt/conda/envs/stac-mjx-env/lib/python3.11/site-packages (from jsonschema>=2.6.0->hdmf<4,>=2.5.6->ndx-pose) (0.20.0)\n", + "Requirement already satisfied: pytz>=2020.1 in /opt/conda/envs/stac-mjx-env/lib/python3.11/site-packages (from pandas>=1.0.5->hdmf<4,>=2.5.6->ndx-pose) (2024.1)\n", + "Requirement already satisfied: tzdata>=2022.7 in /opt/conda/envs/stac-mjx-env/lib/python3.11/site-packages (from pandas>=1.0.5->hdmf<4,>=2.5.6->ndx-pose) (2024.2)\n", + "Requirement already satisfied: six>=1.5 in /opt/conda/envs/stac-mjx-env/lib/python3.11/site-packages (from python-dateutil>=2.7.3->pynwb<3,>=1.5.0->ndx-pose) (1.16.0)\n", + "Requirement already satisfied: ruamel.yaml.clib>=0.2.7 in /opt/conda/envs/stac-mjx-env/lib/python3.11/site-packages (from ruamel-yaml>=0.16->hdmf<4,>=2.5.6->ndx-pose) (0.2.8)\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n", + "\u001b[0m" + ] + } + ], + "source": [ + "!pip install ndx-pose" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "HG8_BB3JCHni", + "outputId": "90fcc09d-2d5e-4443-f0fb-9622fe8a5e57" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading...\n", + "From (original): https://drive.google.com/uc?id=1nN9XWL2L-ZeSb5rqu38gC28EOVlYuaOq\n", + "From (redirected): https://drive.google.com/uc?id=1nN9XWL2L-ZeSb5rqu38gC28EOVlYuaOq&confirm=t&uuid=f6e015d8-9625-4b60-a583-959012baf7ac\n", + "To: /content/save_data_AVG.mat\n", + "100% 467M/467M [00:06<00:00, 75.4MB/s]\n" + ] + } + ], + "source": [ + "!gdown --fuzzy https://drive.google.com/file/d/1nN9XWL2L-ZeSb5rqu38gC28EOVlYuaOq/view?usp=drive_link" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "3HdLigI8Cf97", + "outputId": "0b11e244-ba33-417e-9c73-469827eb628c" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dict_keys(['__header__', '__version__', '__globals__', 'pred', 'data', 'p_max', 'sampleID', 'metadata'])\n", + "(360000, 3, 23)\n", + "(360000, 23)\n" + ] + } + ], + "source": [ + "from scipy.io import loadmat\n", + "import datetime\n", + "import numpy as np\n", + "from pynwb import NWBFile, NWBHDF5IO\n", + "from ndx_pose import PoseEstimationSeries, PoseEstimation\n", + "\n", + "\n", + "mat = loadmat(\"save_data_AVG.mat\")\n", + "print(mat.keys())\n", + "print(mat[\"data\"].shape)\n", + "print(mat[\"p_max\"].shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "O7MtC-C3DFAX", + "outputId": "ccc1e9c2-1f3c-418d-f9c6-4fa1e7975e26" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "augment_brightness: [[array([[0]])]]\n", + "augment_hue: [[array([[0]])]]\n", + "com_exp: [[array([[(array(['/n/holylfs02/LABS/olveczky_lab/Everyone/dannce_rig/dannce_ephys/art/2020_12_21_1/20201221_163226_Label3D_dannce.mat'],\n", + " dtype='/dev/null || (apt update && apt install -y ffmpeg)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Model is hard coded version of synth_model.xml. TODO: Read from file.\n", + "\n", + "chaotic_pendulum = \"\"\"\n", + "\n", + " \n", + "\n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + "\"\"\"\n", + "\n", + "model = mujoco.MjModel.from_xml_string(chaotic_pendulum)\n", + "data = mujoco.MjData(model)\n", + "height = 480\n", + "width = 640\n", + "\n", + "with mujoco.Renderer(model, height, width) as renderer:\n", + " mujoco.mj_forward(model, data)\n", + " renderer.update_scene(data, camera=\"fixed\")\n", + "\n", + " media.show_image(renderer.render())" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "simulation: 1.74 μs/step (574458Hz)\n", + "rendering: 4.65e+03 μs/frame ( 215Hz)\n", + "\n", + "\n" + ] + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# setup\n", + "n_seconds = 6\n", + "framerate = 30 # Hz\n", + "n_frames = int(n_seconds * framerate)\n", + "frames = []\n", + "height = 240\n", + "width = 320\n", + "\n", + "# set initial state\n", + "mujoco.mj_resetData(model, data)\n", + "data.joint('root').qvel = 2\n", + "\n", + "# simulate and record frames\n", + "frame = 0\n", + "sim_time = 0\n", + "render_time = 0\n", + "n_steps = 0\n", + "with mujoco.Renderer(model, height, width) as renderer:\n", + " for i in range(n_frames):\n", + " while data.time * framerate < i:\n", + " tic = time.time()\n", + " mujoco.mj_step(model, data)\n", + " sim_time += time.time() - tic\n", + " n_steps += 1\n", + " tic = time.time()\n", + " renderer.update_scene(data, \"fixed\")\n", + " frame = renderer.render()\n", + " render_time += time.time() - tic\n", + " frames.append(frame)\n", + "\n", + "# print timing and play video\n", + "step_time = 1e6*sim_time/n_steps\n", + "step_fps = n_steps/sim_time\n", + "print(f'simulation: {step_time:5.3g} μs/step ({step_fps:5.0f}Hz)')\n", + "frame_time = 1e6*render_time/n_frames\n", + "frame_fps = n_frames/render_time\n", + "print(f'rendering: {frame_time:5.3g} μs/frame ({frame_fps:5.0f}Hz)')\n", + "print('\\n')\n", + "\n", + "# show video\n", + "media.show_video(frames, fps=framerate)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# TODO: Modify the code below for data generation. \n", + "\n", + "# PERTURBATION = 1e-7\n", + "# SIM_DURATION = 10.0 # seconds\n", + "\n", + "# # preallocate\n", + "# n_steps = int(SIM_DURATION / model.opt.timestep)\n", + "# sim_time = np.zeros(n_steps)\n", + "# x0 = np.zeros((n_steps, 3))\n", + "# x1 = np.zeros((n_steps, 3))\n", + "\n", + "# # prepare plotting axes\n", + "# _, ax = plt.subplots(2, 1, figsize=(8, 6), sharex=True)\n", + "\n", + "# # initialize\n", + "# mujoco.mj_resetData(model, data)\n", + "# data.qvel[0] = 10 # root joint velocity\n", + "# # perturb initial velocities\n", + "# #data.qvel[:] += PERTURBATION * np.random.randn(model.nv)\n", + "\n", + "# # simulate\n", + "# for i in range(n_steps):\n", + "# mujoco.mj_step(model, data)\n", + "# sim_time[i] = data.time\n", + "# #angle[i] = data.joint('root').qpos\n", + "# x0[i] = data.body('0').xpos\n", + "# #x1[i] = data.body('1').xpos\n", + "# #print(data.body('1').xpos[0])\n", + "\n", + "# # plot\n", + "# ax[0].plot(sim_time, x0[:,2])\n", + "# ax[1].plot(sim_time, x1[:,2])\n", + "\n", + "# # finalize plot\n", + "# ax[0].set_title('x0')\n", + "# ax[0].set_ylabel('m')\n", + "# ax[1].set_title('x1')\n", + "# ax[1].set_ylabel('m')\n", + "# ax[1].set_xlabel('time (s)')\n", + "# plt.tight_layout()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "stac-mjx-env", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.10" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/models/synth_model.xml b/models/synth_model.xml new file mode 100644 index 0000000..9585c88 --- /dev/null +++ b/models/synth_model.xml @@ -0,0 +1,18 @@ + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/stac_mjx/compute_stac.py b/stac_mjx/compute_stac.py index 632cd98..24a2597 100644 --- a/stac_mjx/compute_stac.py +++ b/stac_mjx/compute_stac.py @@ -14,6 +14,7 @@ def root_optimization( mjx_model, mjx_data, kp_data: jp.ndarray, + root_kp_idx: int, lb: jp.ndarray, ub: jp.ndarray, site_idxs: jp.ndarray, @@ -50,7 +51,7 @@ def root_optimization( # necessarily exactly so. The value of 3*18 is chosen for the # rodent.xml, corresponding to the index of 'SpineL' keypoint. # For the mouse model this should be 3*5, corresponding 'Trunk' - root_kp_idx = 3 * 18 + # root_kp_idx = 3 * 18 # FLY_MODEL: # root_kp_idx = 0 q0.at[:3].set(kp_data[frame, :][root_kp_idx : root_kp_idx + 3]) diff --git a/stac_mjx/stac.py b/stac_mjx/stac.py index a03822f..96ec88b 100644 --- a/stac_mjx/stac.py +++ b/stac_mjx/stac.py @@ -80,15 +80,22 @@ def __init__(self, xml_path: str, cfg: DictConfig, kp_names: List[str]): self._mj_model.body(i).name for i in range(self._mj_model.nbody) ] - joint_names = [self._mj_model.joint(i).name for i in range(self._mj_model.njnt)] + if "ROOT_OPTIMIZATION_KEYPOINT" in self.cfg.model: + self._root_kp_idx = self._kp_names.index( + self.cfg.model.ROOT_OPTIMIZATION_KEYPOINT + ) + else: + self._root_kp_idx = -1 # Set up bounds and part_names based on joint ranges, taking into account the dimensionality of parameters + joint_names = [self._mj_model.joint(i).name for i in range(self._mj_model.njnt)] self._lb, self._ub, self._part_names = _align_joint_dims( self._mj_model.jnt_type, self._mj_model.jnt_range, joint_names ) self._indiv_parts = self.part_opt_setup() + # Generate boolean flags for keypoints included in trunk optimization. self._trunk_kps = jp.array( [n in self.cfg.model.TRUNK_OPTIMIZATION_KEYPOINTS for n in kp_names], ) @@ -113,7 +120,7 @@ def get_part_ids(parts: List) -> jp.ndarray: [any(part in name for part in parts) for name in self._part_names] ) - if self.cfg.model.INDIVIDUAL_PART_OPTIMIZATION is None: + if "INDIVIDUAL_PART_OPTIMIZATION" not in self.cfg.model: indiv_parts = [] else: indiv_parts = jp.array( @@ -224,11 +231,16 @@ def fit_offsets(self, kp_data): # Begin optimization steps # Skip root optimization if model is fixed (no free joint at root) - if self._mj_model.jnt_type[0] == mujoco.mjtJoint.mjJNT_FREE: + if self._root_kp_idx == -1: + print( + "ROOT_OPTIMIZATION_KEYPOINT not specified, skipping Root Optimization." + ) + elif self._mj_model.jnt_type[0] == mujoco.mjtJoint.mjJNT_FREE: mjx_data = compute_stac.root_optimization( mjx_model, mjx_data, kp_data, + self._root_kp_idx, self._lb, self._ub, self._body_site_idxs, @@ -339,15 +351,20 @@ def mjx_setup(kp_data, mj_model): ) # q_phase - root - if self._mj_model.jnt_type[0] == mujoco.mjtJoint.mjJNT_FREE: + if self._root_kp_idx == -1: + print( + "Missing or invalid ROOT_OPTIMIZATION_KEYPOINT, skipping root_optimization()" + ) + elif self._mj_model.jnt_type[0] == mujoco.mjtJoint.mjJNT_FREE: vmap_root_opt = jax.vmap( compute_stac.root_optimization, - in_axes=(0, 0, 0, None, None, None, None), + in_axes=(0, 0, 0, None, None, None, None, None), ) mjx_data = vmap_root_opt( mjx_model, mjx_data, batched_kp_data, + self._root_kp_idx, self._lb, self._ub, self._body_site_idxs, diff --git a/tests/configs/model/test_mouse.yaml b/tests/configs/model/test_mouse.yaml index 7091175..4c740c4 100644 --- a/tests/configs/model/test_mouse.yaml +++ b/tests/configs/model/test_mouse.yaml @@ -122,6 +122,8 @@ KEYPOINT_INITIAL_OFFSETS: Lisfranc_L: 0.0 0.0 0.0 MTP_R: 0.0 0.0 0.0 +ROOT_OPTIMIZATION_KEYPOINT: Trunk + TRUNK_OPTIMIZATION_KEYPOINTS: - "Trunk" - "HipL" diff --git a/tests/configs/model/test_rodent.yaml b/tests/configs/model/test_rodent.yaml index 3ec71e4..154583f 100644 --- a/tests/configs/model/test_rodent.yaml +++ b/tests/configs/model/test_rodent.yaml @@ -80,6 +80,8 @@ KEYPOINT_INITIAL_OFFSETS: WristL: 0. 0. 0.0 WristR: 0. 0. 0.0 +ROOT_OPTIMIZATION_KEYPOINT: SpineL + TRUNK_OPTIMIZATION_KEYPOINTS: - "Spine" - "Hip" diff --git a/tests/configs/model/test_rodent_label3d.yaml b/tests/configs/model/test_rodent_label3d.yaml index 66d9854..46d6af8 100644 --- a/tests/configs/model/test_rodent_label3d.yaml +++ b/tests/configs/model/test_rodent_label3d.yaml @@ -8,7 +8,6 @@ N_FRAMES_PER_CLIP: 360 # presumed to be derived from label3d: KP_NAMES_LABEL3D_PATH: "tests/data/rat23.mat" - # The model sites used to register the keypoints. KEYPOINT_MODEL_PAIRS: AnkleL: lower_leg_L @@ -61,6 +60,8 @@ KEYPOINT_INITIAL_OFFSETS: WristL: 0. 0. 0.0 WristR: 0. 0. 0.0 +ROOT_OPTIMIZATION_KEYPOINT: SpineL + TRUNK_OPTIMIZATION_KEYPOINTS: - "Spine" - "Hip" diff --git a/tests/configs/model/test_rodent_less_kp_names.yaml b/tests/configs/model/test_rodent_less_kp_names.yaml index 699d137..6cdf6ac 100644 --- a/tests/configs/model/test_rodent_less_kp_names.yaml +++ b/tests/configs/model/test_rodent_less_kp_names.yaml @@ -59,6 +59,8 @@ KEYPOINT_INITIAL_OFFSETS: WristL: 0. 0. 0.0 WristR: 0. 0. 0.0 +ROOT_OPTIMIZATION_KEYPOINT: SpineL + TRUNK_OPTIMIZATION_KEYPOINTS: - "Spine" - "Hip" diff --git a/tests/configs/model/test_rodent_no_kp_names.yaml b/tests/configs/model/test_rodent_no_kp_names.yaml index 1e64589..bab8a99 100644 --- a/tests/configs/model/test_rodent_no_kp_names.yaml +++ b/tests/configs/model/test_rodent_no_kp_names.yaml @@ -58,6 +58,8 @@ KEYPOINT_INITIAL_OFFSETS: WristL: 0. 0. 0.0 WristR: 0. 0. 0.0 +ROOT_OPTIMIZATION_KEYPOINT: SpineL + TRUNK_OPTIMIZATION_KEYPOINTS: - "Spine" - "Hip" diff --git a/tests/data/test_synth_1_frames.nwb b/tests/data/test_synth_1_frames.nwb new file mode 100644 index 0000000..b7a820d Binary files /dev/null and b/tests/data/test_synth_1_frames.nwb differ