Skip to content

Commit

Permalink
Fixed type errors to unblock an internal type annotations refactoring…
Browse files Browse the repository at this point in the history
… in JAX

Some JAX internal used Any instead of Array or in their type annotations.
jax-ml/jax#17760 changed these to alias jax.Array and
uncovered type errors fixed here.

PiperOrigin-RevId: 569994709
  • Loading branch information
superbobry authored and copybara-github committed Oct 2, 2023
1 parent 7e7b66e commit 64e0169
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions clrs/_src/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ def test_reduce_permutations(self):
perm = probing.DataPoint(name='test',
type_=specs.Type.PERMUTATION_POINTER,
location=specs.Location.NODE,
data=jax.nn.one_hot(pred, n))
data=np.asarray(jax.nn.one_hot(pred, n)))
mask = probing.DataPoint(name='test_mask',
type_=specs.Type.MASK_ONE,
location=specs.Location.NODE,
data=jax.nn.one_hot(heads, n))
data=np.asarray(jax.nn.one_hot(heads, n)))
output = evaluation.fuse_perm_and_mask(perm=perm, mask=mask)
expected_output = np.array(pred)
expected_output[np.arange(b), heads] = heads
Expand Down
2 changes: 1 addition & 1 deletion clrs/_src/losses_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def test_hint_loss(self, algo):
nb_nodes=nb_nodes,
)

full_preds = pred[1:]
full_preds = list(pred[1:])
full_hint_loss = losses.hint_loss(
truth=_mask_datapoint(truth_full, 1, t_axis=0),
preds=full_preds,
Expand Down

0 comments on commit 64e0169

Please sign in to comment.