From 603818f784d94b8758a65b7b14766841e092d616 Mon Sep 17 00:00:00 2001
From: Sean Moriarity
Date: Tue, 14 May 2024 08:43:37 -0400
Subject: [PATCH] Fix failing test

---
 lib/axon/loop.ex        | 10 ++--------
 test/axon/loop_test.exs |  2 +-
 2 files changed, 3 insertions(+), 9 deletions(-)

diff --git a/lib/axon/loop.ex b/lib/axon/loop.ex
index ca752482..6d16d01c 100644
--- a/lib/axon/loop.ex
+++ b/lib/axon/loop.ex
@@ -367,14 +367,8 @@ defmodule Axon.Loop do
         end)
 
       model_out = forward_model_fn.(model_state, inp)
-
-      {scaled_loss, unscaled_loss} =
-        tar
-        |> loss_fn.(model_out.prediction)
-        |> then(fn loss ->
-          scaled = scale_loss.(loss, loss_scale_state)
-          {scaled, loss}
-        end)
+      unscaled_loss = loss_fn.(tar, model_out.prediction)
+      scaled_loss = scale_loss.(unscaled_loss, loss_scale_state)
 
       {model_out, scaled_loss, unscaled_loss}
     end
diff --git a/test/axon/loop_test.exs b/test/axon/loop_test.exs
index fcd5c5d3..d001f996 100644
--- a/test/axon/loop_test.exs
+++ b/test/axon/loop_test.exs
@@ -62,7 +62,7 @@ defmodule Axon.LoopTest do
     test "trainer/3 returns a supervised training loop with custom loss" do
       model = Axon.input("input", shape: {nil, 1})
 
-      custom_loss_fn = fn _, _ -> Nx.tensor(5.0, backend: Nx.BinaryBackend) end
+      custom_loss_fn = fn _, _ -> Nx.tensor(5.0, backend: Nx.Defn.Expr) end
 
       assert %Loop{init: init_fn, step: update_fn, output_transform: transform} =
               Loop.trainer(model, custom_loss_fn, :adam)