Skip to content

Commit

Permalink
test: fixes Axon.block tests
Browse files Browse the repository at this point in the history
  • Loading branch information
grzuy committed Oct 19, 2023
1 parent 235a030 commit ecd4773
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions test/axon/compiler_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -5382,7 +5382,7 @@ defmodule CompilerTest do

input = random({1, 1})

assert predict_fn.(params, input) == Axon.Layers.dense(input, k, b)
assert_equal(predict_fn.(params, input), Axon.Layers.dense(input, k, b))
end

test "predicts correctly with single dense, used twice" do
Expand All @@ -5400,8 +5400,10 @@ defmodule CompilerTest do

input = random({1, 1})

assert predict_fn.(params, input) ==
input |> Axon.Layers.dense(k, b) |> Axon.Layers.dense(k, b)
assert_equal(
predict_fn.(params, input),
input |> Axon.Layers.dense(k, b) |> Axon.Layers.dense(k, b)
)
end

test "predicts correctly with multiple dense, used once" do
Expand Down Expand Up @@ -5432,7 +5434,7 @@ defmodule CompilerTest do

input = random({1, 1})

assert predict_fn.(params, input) == expected_predict_fn.(input, k1, b1, k2, b2)
assert_equal(predict_fn.(params, input), expected_predict_fn.(input, k1, b1, k2, b2))
end

test "predicts correctly with multiple dense, used twice" do
Expand Down Expand Up @@ -5471,7 +5473,7 @@ defmodule CompilerTest do

input = random({1, 1})

assert predict_fn.(params, input) == expected_predict_fn.(input, k1, b1, k2, b2)
assert_equal(predict_fn.(params, input), expected_predict_fn.(input, k1, b1, k2, b2))
end

test "predicts correctly with multiple blocks in network" do
Expand Down Expand Up @@ -5502,7 +5504,7 @@ defmodule CompilerTest do

input = random({1, 1})

assert predict_fn.(params, input) == actual_predict_fn.(input, k1, b1, k2, b2)
assert_equal(predict_fn.(params, input), actual_predict_fn.(input, k1, b1, k2, b2))
end

test "predicts correctly with block inside of a block" do
Expand Down Expand Up @@ -5535,7 +5537,7 @@ defmodule CompilerTest do
} = params = init_fn.(Nx.template({1, 1}, :f32), %{})

input = random({1, 1})
assert predict_fn.(params, input) == actual_predict_fn.(input, k, b)
assert_equal(predict_fn.(params, input), actual_predict_fn.(input, k, b))
end
end

Expand Down

0 comments on commit ecd4773

Please sign in to comment.