diff --git a/lib/mpl_toolkits/mplot3d/axes3d.py b/lib/mpl_toolkits/mplot3d/axes3d.py index 677c2668d4e9..18b823ccccb3 100644 --- a/lib/mpl_toolkits/mplot3d/axes3d.py +++ b/lib/mpl_toolkits/mplot3d/axes3d.py @@ -383,7 +383,7 @@ def set_box_aspect(self, aspect, *, zoom=1): # of the axes in mpl3.8. aspect *= 1.8294640721620434 * 25/24 * zoom / np.linalg.norm(aspect) - self._box_aspect = aspect + self._box_aspect = self._roll_to_vertical(aspect, reverse=True) self.stale = True def apply_aspect(self, position=None): @@ -1191,9 +1191,23 @@ def set_proj_type(self, proj_type, focal_length=None): f"None for proj_type = {proj_type}") self._focal_length = np.inf - def _roll_to_vertical(self, arr): - """Roll arrays to match the different vertical axis.""" - return np.roll(arr, self._vertical_axis - 2) + def _roll_to_vertical( + self, arr: "np.typing.ArrayLike", reverse: bool = False + ) -> np.ndarray: + """ + Roll arrays to match the different vertical axis. + + Parameters + ---------- + arr : ArrayLike + Array to roll. + reverse : bool, default: False + Reverse the direction of the roll. + """ + if reverse: + return np.roll(arr, (self._vertical_axis - 2) * -1) + else: + return np.roll(arr, (self._vertical_axis - 2)) def get_proj(self): """Create the projection matrix from the current viewing position.""" diff --git a/lib/mpl_toolkits/mplot3d/tests/test_axes3d.py b/lib/mpl_toolkits/mplot3d/tests/test_axes3d.py index c339e35e903c..34f10cb9fd63 100644 --- a/lib/mpl_toolkits/mplot3d/tests/test_axes3d.py +++ b/lib/mpl_toolkits/mplot3d/tests/test_axes3d.py @@ -2301,6 +2301,24 @@ def test_on_move_vertical_axis(vertical_axis: str) -> None: ) +@pytest.mark.parametrize( + "vertical_axis, aspect_expected", + [ + ("x", [1.190476, 0.892857, 1.190476]), + ("y", [0.892857, 1.190476, 1.190476]), + ("z", [1.190476, 1.190476, 0.892857]), + ], +) +def test_set_box_aspect_vertical_axis(vertical_axis, aspect_expected): + ax = plt.subplot(1, 1, 1, projection="3d") + ax.view_init(elev=0, azim=0, roll=0, vertical_axis=vertical_axis) + ax.figure.canvas.draw() + + ax.set_box_aspect(None) + + np.testing.assert_allclose(aspect_expected, ax._box_aspect, rtol=1e-6) + + @image_comparison(baseline_images=['arc_pathpatch.png'], remove_text=True, style='mpl20')