diff --git a/lib/axon/loop.ex b/lib/axon/loop.ex index 6d16d01c..e21d8666 100644 --- a/lib/axon/loop.ex +++ b/lib/axon/loop.ex @@ -1556,7 +1556,7 @@ defmodule Axon.Loop do is set, the loop will raise on any cache miss during the training loop. Defaults to true. - * `:force_garbage_collect?` - whether or not to force garbage collection after each + * `:force_garbage_collection?` - whether or not to force garbage collection after each iteration. This may help avoid OOMs when training large models, but it will slow training down. diff --git a/lib/axon/model_state.ex b/lib/axon/model_state.ex index 23ac10be..8eede9a6 100644 --- a/lib/axon/model_state.ex +++ b/lib/axon/model_state.ex @@ -222,13 +222,27 @@ defmodule Axon.ModelState do end defp tree_get(data, access) when is_list(access) do - Enum.reduce(access, %{}, &Map.put(&2, &1, Map.fetch!(data, &1))) + Enum.reduce(access, %{}, fn key, acc -> + case data do + %{^key => val} -> + Map.put(acc, key, val) + + %{} -> + acc + end + end) end defp tree_get(data, access) when is_map(access) do Enum.reduce(access, %{}, fn {key, value}, acc -> - tree = tree_get(data[key], value) - Map.put(acc, key, tree) + case data do + %{^key => val} -> + tree = tree_get(val, value) + Map.put(acc, key, tree) + + %{} -> + acc + end end) end