Skip to content

Commit

Permalink
Fix missing state in training (#579)
Browse files Browse the repository at this point in the history
* Inspect

* Inspect more

* Raise on key

* inspect

* Fix?

* Again

* Again

* Fix axon

* It works
  • Loading branch information
seanmor5 authored May 30, 2024
1 parent efd4c1f commit 57cd12f
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
2 changes: 1 addition & 1 deletion lib/axon/loop.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1556,7 +1556,7 @@ defmodule Axon.Loop do
is set, the loop will raise on any cache miss during the training loop. Defaults
to true.
* `:force_garbage_collect?` - whether or not to force garbage collection after each
* `:force_garbage_collection?` - whether or not to force garbage collection after each
iteration. This may help avoid OOMs when training large models, but it will slow
training down.
Expand Down
20 changes: 17 additions & 3 deletions lib/axon/model_state.ex
Original file line number Diff line number Diff line change
Expand Up @@ -222,13 +222,27 @@ defmodule Axon.ModelState do
end

defp tree_get(data, access) when is_list(access) do
Enum.reduce(access, %{}, &Map.put(&2, &1, Map.fetch!(data, &1)))
Enum.reduce(access, %{}, fn key, acc ->
case data do
%{^key => val} ->
Map.put(acc, key, val)

%{} ->
acc
end
end)
end

defp tree_get(data, access) when is_map(access) do
Enum.reduce(access, %{}, fn {key, value}, acc ->
tree = tree_get(data[key], value)
Map.put(acc, key, tree)
case data do
%{^key => val} ->
tree = tree_get(val, value)
Map.put(acc, key, tree)

%{} ->
acc
end
end)
end

Expand Down

0 comments on commit 57cd12f

Please sign in to comment.