diff --git a/clrs/_src/evaluation_test.py b/clrs/_src/evaluation_test.py index 4f83ac39..b7fde2ac 100644 --- a/clrs/_src/evaluation_test.py +++ b/clrs/_src/evaluation_test.py @@ -37,11 +37,11 @@ def test_reduce_permutations(self): perm = probing.DataPoint(name='test', type_=specs.Type.PERMUTATION_POINTER, location=specs.Location.NODE, - data=jax.nn.one_hot(pred, n)) + data=np.asarray(jax.nn.one_hot(pred, n))) mask = probing.DataPoint(name='test_mask', type_=specs.Type.MASK_ONE, location=specs.Location.NODE, - data=jax.nn.one_hot(heads, n)) + data=np.asarray(jax.nn.one_hot(heads, n))) output = evaluation.fuse_perm_and_mask(perm=perm, mask=mask) expected_output = np.array(pred) expected_output[np.arange(b), heads] = heads diff --git a/clrs/_src/losses_test.py b/clrs/_src/losses_test.py index e7600cd7..44a181fa 100644 --- a/clrs/_src/losses_test.py +++ b/clrs/_src/losses_test.py @@ -152,7 +152,7 @@ def test_hint_loss(self, algo): nb_nodes=nb_nodes, ) - full_preds = pred[1:] + full_preds = list(pred[1:]) full_hint_loss = losses.hint_loss( truth=_mask_datapoint(truth_full, 1, t_axis=0), preds=full_preds,