Skip to content

Commit

Permalink
Fix failing test
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmor5 committed May 14, 2024
1 parent 4851084 commit 603818f
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 9 deletions.
10 changes: 2 additions & 8 deletions lib/axon/loop.ex
Original file line number Diff line number Diff line change
Expand Up @@ -367,14 +367,8 @@ defmodule Axon.Loop do
end)

model_out = forward_model_fn.(model_state, inp)

{scaled_loss, unscaled_loss} =
tar
|> loss_fn.(model_out.prediction)
|> then(fn loss ->
scaled = scale_loss.(loss, loss_scale_state)
{scaled, loss}
end)
unscaled_loss = loss_fn.(tar, model_out.prediction)
scaled_loss = scale_loss.(unscaled_loss, loss_scale_state)

{model_out, scaled_loss, unscaled_loss}
end
Expand Down
2 changes: 1 addition & 1 deletion test/axon/loop_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ defmodule Axon.LoopTest do

test "trainer/3 returns a supervised training loop with custom loss" do
model = Axon.input("input", shape: {nil, 1})
custom_loss_fn = fn _, _ -> Nx.tensor(5.0, backend: Nx.BinaryBackend) end
custom_loss_fn = fn _, _ -> Nx.tensor(5.0, backend: Nx.Defn.Expr) end

assert %Loop{init: init_fn, step: update_fn, output_transform: transform} =
Loop.trainer(model, custom_loss_fn, :adam)
Expand Down

0 comments on commit 603818f

Please sign in to comment.