From f3d5bf964dc36b212fd06636fbce5d6d1f2e8d17 Mon Sep 17 00:00:00 2001 From: Sean Moriarity Date: Tue, 14 May 2024 05:31:05 -0700 Subject: [PATCH] Fix iteration counts (#572) --- lib/axon/loop.ex | 86 ++++++++++++++++++++++++++++++++++------- lib/axon/loop/state.ex | 4 +- test/axon/loop_test.exs | 23 +++++------ 3 files changed, 86 insertions(+), 27 deletions(-) diff --git a/lib/axon/loop.ex b/lib/axon/loop.ex index 232ab768..2637258a 100644 --- a/lib/axon/loop.ex +++ b/lib/axon/loop.ex @@ -690,7 +690,7 @@ defmodule Axon.Loop do loop |> log(&supervised_log_message_fn/1, event: :iteration_completed, - filter: [every: log_interval] + filter: [every: {:epoch, log_interval}] ) |> log(fn _ -> "\n" end, event: :epoch_completed) else @@ -1912,6 +1912,29 @@ defmodule Axon.Loop do end) end + defp update_counts(%State{event_counts: event_counts} = state, event) + when event in [:iteration_started, :iteration_completed] do + updated_counts = + Map.update(event_counts, event, %{total: 1, epoch: 1}, fn total_and_epoch -> + total_and_epoch + |> Map.update!(:total, &(&1 + 1)) + |> Map.update!(:epoch, &(&1 + 1)) + end) + + %{state | event_counts: updated_counts} + end + + defp update_counts(%State{event_counts: event_counts} = state, event) + when event in [:epoch_halted, :epoch_completed] do + updated_counts = + event_counts + |> Map.update(:iteration_started, %{total: 0, epoch: 0}, &%{&1 | epoch: 0}) + |> Map.update(:iteration_completed, %{total: 0, epoch: 0}, &%{&1 | epoch: 0}) + |> Map.update(event, 1, &(&1 + 1)) + + %{state | event_counts: updated_counts} + end + defp update_counts(%State{event_counts: event_counts} = state, event) do %{state | event_counts: Map.update(event_counts, event, 1, fn x -> x + 1 end)} end @@ -2165,29 +2188,53 @@ defmodule Axon.Loop do :first -> fn %State{event_counts: counts}, event -> - counts[event] == 1 + case counts[event] do + 1 -> true + %{total: 1} -> true + _ -> false + end end filters when is_list(filters) -> Enum.reduce(filters, fn _, _ -> true end, fn + {:every, {key, n}}, acc -> + fn state, event -> + acc.(state, event) and filter_every_n(state, event, key, n) + end + {:every, n}, acc -> fn state, event -> - acc.(state, event) and filter_every_n(state, event, n) + acc.(state, event) and filter_every_n(state, event, :total, n) + end + + {:before, {key, n}}, acc -> + fn state, event -> + acc.(state, event) and filter_before_n(state, event, key, n) end {:before, n}, acc -> fn state, event -> - acc.(state, event) and filter_before_n(state, event, n) + acc.(state, event) and filter_before_n(state, event, :total, n) + end + + {:after, {key, n}}, acc -> + fn state, event -> + acc.(state, event) and filter_after_n(state, event, key, n) end {:after, n}, acc -> fn state, event -> - acc.(state, event) and filter_after_n(state, event, n) + acc.(state, event) and filter_after_n(state, event, :total, n) + end + + {:once, {key, n}}, acc -> + fn state, event -> + acc.(state, event) and filter_once_n(state, event, key, n) end {:once, n}, acc -> fn state, event -> - acc.(state, event) and filter_once_n(state, event, n) + acc.(state, event) and filter_once_n(state, event, :total, n) end end) @@ -2204,20 +2251,31 @@ defmodule Axon.Loop do end end - defp filter_every_n(%State{event_counts: counts}, event, n) do - rem(counts[event] - 1, n) == 0 + defp filter_every_n(%State{event_counts: counts}, event, key, n) do + count = get_count(counts, event, key) + rem(count - 1, n) == 0 end - defp filter_after_n(%State{event_counts: counts}, event, n) do - counts[event] > n + defp filter_after_n(%State{event_counts: counts}, event, key, n) do + count = get_count(counts, event, key) + count > n end - defp filter_before_n(%State{event_counts: counts}, event, n) do - counts[event] < n + defp filter_before_n(%State{event_counts: counts}, event, key, n) do + count = get_count(counts, event, key) + count < n end - defp filter_once_n(%State{event_counts: counts}, event, n) do - counts[event] == n + defp filter_once_n(%State{event_counts: counts}, event, key, n) do + count = get_count(counts, event, key) + count == n + end + + defp get_count(counts, event, key) do + case counts[event] do + %{^key => count} -> count + count -> count + end end # JIT-compiles the given function if jit_compile? is true diff --git a/lib/axon/loop/state.ex b/lib/axon/loop/state.ex index eaccef22..56a7c9cb 100644 --- a/lib/axon/loop/state.ex +++ b/lib/axon/loop/state.ex @@ -60,8 +60,8 @@ defmodule Axon.Loop.State do event_counts: %{ started: 0, epoch_started: 0, - iteration_started: 0, - iteration_completed: 0, + iteration_started: %{total: 0, epoch: 0}, + iteration_completed: %{total: 0, epoch: 0}, epoch_completed: 0, epoch_halted: 0, halted: 0, diff --git a/test/axon/loop_test.exs b/test/axon/loop_test.exs index 34c002ab..fcd5c5d3 100644 --- a/test/axon/loop_test.exs +++ b/test/axon/loop_test.exs @@ -636,8 +636,8 @@ defmodule Axon.LoopTest do started: 1, epoch_started: 1, epoch_completed: 1, - iteration_started: 10, - iteration_completed: 10 + iteration_started: %{total: 10, epoch: 0}, + iteration_completed: %{total: 10, epoch: 0} }} assert_received {:epoch_started, @@ -645,8 +645,8 @@ defmodule Axon.LoopTest do started: 1, epoch_started: 2, epoch_completed: 1, - iteration_started: 10, - iteration_completed: 10 + iteration_started: %{total: 10, epoch: 0}, + iteration_completed: %{total: 10, epoch: 0} }} assert_received {:epoch_completed, @@ -654,8 +654,8 @@ defmodule Axon.LoopTest do started: 1, epoch_started: 2, epoch_completed: 2, - iteration_started: 20, - iteration_completed: 20 + iteration_started: %{total: 20, epoch: 0}, + iteration_completed: %{total: 20, epoch: 0} }} refute_received _ @@ -786,7 +786,7 @@ defmodule Axon.LoopTest do test "supports function filter" do fun = fn - %{event_counts: counts}, event -> counts[event] == 5 + %{event_counts: counts}, event -> counts[event][:total] == 5 end run_dummy_loop!(:iteration_started, fun, 5, 10) @@ -854,18 +854,19 @@ defmodule Axon.LoopTest do test "saves a checkpoint on custom events", %{loop: loop} do data = List.duplicate({Nx.iota({1, 1}), Nx.iota({1, 1})}, 5) - assert %Axon.Loop.State{epoch: 3, iteration: 0, event_counts: %{iteration_completed: 15}} = + assert %Axon.Loop.State{epoch: 3, iteration: 0, event_counts: %{iteration_completed: %{total: 15}}} = loop |> Map.put(:output_transform, & &1) - |> Loop.checkpoint(event: :iteration_completed, filter: [every: 2]) + |> Loop.checkpoint(event: :iteration_completed, filter: [every: {:epoch, 2}]) |> Loop.run(data, Axon.ModelState.empty(), epochs: 3) assert [ "checkpoint_0_0.ckpt", "checkpoint_0_2.ckpt", "checkpoint_0_4.ckpt", - "checkpoint_1_1.ckpt", - "checkpoint_1_3.ckpt", + "checkpoint_1_0.ckpt", + "checkpoint_1_2.ckpt", + "checkpoint_1_4.ckpt", "checkpoint_2_0.ckpt", "checkpoint_2_2.ckpt", "checkpoint_2_4.ckpt"