Skip to content

Commit

Permalink
Relax slerp.interpolate() input shape test and make it coherent with …
Browse files Browse the repository at this point in the history
…docstring.

PiperOrigin-RevId: 503185356
  • Loading branch information
tensorflower-gardener authored and copybara-github committed Jan 19, 2023
1 parent 98c0d63 commit 6dc8101
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
8 changes: 4 additions & 4 deletions tensorflow_graphics/math/interpolation/slerp.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,8 @@ def quaternion_weights(
quaternions in its last dimension.
quaternion2: A tensor of shape `[A1, ... , An, 4]` storing normalized
quaternions in its last dimension.
percent: A `float` or a tensor with a shape broadcastable to the shape `[A1,
... , An]`.
percent: A `float` or tensor with shape broadcastable to the shape of input
vectors.
eps: A `float` used to make operations safe. When left as None, the function
automatically picks the best epsilon based on the dtype and the operation.
name: A name for this op. Defaults to "quaternion_weights".
Expand All @@ -198,7 +198,7 @@ def quaternion_weights(
tensor=quaternion2, tensor_name="quaternion2", has_dim_equals=(-1, 4))
shape.compare_batch_dimensions(
tensors=(quaternion1, quaternion2, percent),
last_axes=(-2, -2, -1),
last_axes=-1,
broadcast_compatible=True,
tensor_names=("quaternion1", "quaternion2", "percent"))
quaternion1 = asserts.assert_normalized(quaternion1)
Expand Down Expand Up @@ -266,7 +266,7 @@ def vector_weights(vector1: type_alias.TensorLike,
tensor_names=("vector1", "vector2"))
shape.compare_batch_dimensions(
tensors=(vector1, vector2, percent),
last_axes=(-2, -2, -1),
last_axes=-1,
broadcast_compatible=True,
tensor_names=("vector1", "vector2", "percent"))
normalized1 = tf.nn.l2_normalize(vector1, axis=-1)
Expand Down
4 changes: 3 additions & 1 deletion tensorflow_graphics/math/interpolation/tests/slerp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def test_unnormalized_quaternion_weights_exception_raised(self):
@parameterized.parameters(
((4,), (4,), (1,)),
((None, 4), (None, 4), (None, 1)),
((None, 4), (None, 4), (None, 4)),
((5, 1, 4), (5, 1, 4), (3, 1)),
)
def test_quaternion_weights_exception_not_raised(self, *shapes):
"""Tests that valid input shapes do not raise exceptions for qslerp."""
Expand All @@ -140,6 +140,8 @@ def test_quaternion_weights_exception_not_raised(self, *shapes):
(1,)),
("Not all batch dimensions are broadcast-compatible.", (1, 4), (3, 4),
(2,)),
("Not all batch dimensions are broadcast-compatible.", (5, 1, 4),
(5, 1, 4), (3,)),
)
def test_quaternion_weights_exception_raised(self, error_msg, *shapes):
"""Tests that the shape exceptions are properly raised for qslerp."""
Expand Down

0 comments on commit 6dc8101

Please sign in to comment.