diff --git a/test/axon/compiler_test.exs b/test/axon/compiler_test.exs index 5db0d7ab..e16723d2 100644 --- a/test/axon/compiler_test.exs +++ b/test/axon/compiler_test.exs @@ -5382,7 +5382,7 @@ defmodule CompilerTest do input = random({1, 1}) - assert predict_fn.(params, input) == Axon.Layers.dense(input, k, b) + assert_equal(predict_fn.(params, input), Axon.Layers.dense(input, k, b)) end test "predicts correctly with single dense, used twice" do @@ -5400,8 +5400,10 @@ defmodule CompilerTest do input = random({1, 1}) - assert predict_fn.(params, input) == - input |> Axon.Layers.dense(k, b) |> Axon.Layers.dense(k, b) + assert_equal( + predict_fn.(params, input), + input |> Axon.Layers.dense(k, b) |> Axon.Layers.dense(k, b) + ) end test "predicts correctly with multiple dense, used once" do @@ -5432,7 +5434,7 @@ defmodule CompilerTest do input = random({1, 1}) - assert predict_fn.(params, input) == expected_predict_fn.(input, k1, b1, k2, b2) + assert_equal(predict_fn.(params, input), expected_predict_fn.(input, k1, b1, k2, b2)) end test "predicts correctly with multiple dense, used twice" do @@ -5471,7 +5473,7 @@ defmodule CompilerTest do input = random({1, 1}) - assert predict_fn.(params, input) == expected_predict_fn.(input, k1, b1, k2, b2) + assert_equal(predict_fn.(params, input), expected_predict_fn.(input, k1, b1, k2, b2)) end test "predicts correctly with multiple blocks in network" do @@ -5502,7 +5504,7 @@ defmodule CompilerTest do input = random({1, 1}) - assert predict_fn.(params, input) == actual_predict_fn.(input, k1, b1, k2, b2) + assert_equal(predict_fn.(params, input), actual_predict_fn.(input, k1, b1, k2, b2)) end test "predicts correctly with block inside of a block" do @@ -5535,7 +5537,7 @@ defmodule CompilerTest do } = params = init_fn.(Nx.template({1, 1}, :f32), %{}) input = random({1, 1}) - assert predict_fn.(params, input) == actual_predict_fn.(input, k, b) + assert_equal(predict_fn.(params, input), actual_predict_fn.(input, k, b)) end end