From 831de556227980e62bb2e22f3884a440edd9ad75 Mon Sep 17 00:00:00 2001 From: Sean Moriarity Date: Fri, 10 May 2024 15:08:03 -0400 Subject: [PATCH] Proper bidirectional implementation --- lib/axon.ex | 24 +++++++++++++++++++----- lib/axon/model_state.ex | 4 ++-- test/axon/compiler_test.exs | 29 +++++++++++++++++++++++++++++ test/axon/integration_test.exs | 8 ++++++-- 4 files changed, 56 insertions(+), 9 deletions(-) diff --git a/lib/axon.ex b/lib/axon.ex index 458f1384..6c30ce77 100644 --- a/lib/axon.ex +++ b/lib/axon.ex @@ -2357,17 +2357,31 @@ defmodule Axon do """ def bidirectional(%Axon{} = input, forward_fun, merge_fun, opts \\ []) when is_function(forward_fun, 1) and is_function(merge_fun, 2) do - opts = Keyword.validate!(opts, axis: 1) + opts = Keyword.validate!(opts, [:name, axis: 1]) - forward_out = forward_fun.(input) + fun = + Axon.block( + fn x -> + Axon.container(forward_fun.(x)) + end, + name: opts[:name] + ) + + forward_out = fun.(input) backward_out = input |> Axon.nx(&Nx.reverse(&1, axes: [opts[:axis]])) - |> forward_fun.() - |> deep_new(&Axon.nx(&1, fn x -> Nx.reverse(x, axes: [opts[:axis]]) end)) + |> fun.() + |> Axon.nx(fn x -> + deep_new(x, &Nx.reverse(&1, axes: [opts[:axis]])) + end) - deep_merge(forward_out, backward_out, merge_fun) + {forward_out, backward_out} + |> Axon.container() + |> Axon.nx(fn {forward, backward} -> + deep_merge(forward, backward, merge_fun) + end) end @doc """ diff --git a/lib/axon/model_state.ex b/lib/axon/model_state.ex index 8604789f..81e14e7f 100644 --- a/lib/axon/model_state.ex +++ b/lib/axon/model_state.ex @@ -23,7 +23,7 @@ defmodule Axon.ModelState do } = model_state, updated_parameters, updated_state \\ %{} - ) do + ) do updated_state = state |> tree_diff(frozen) @@ -215,7 +215,7 @@ defmodule Axon.ModelState do Enum.reduce(access, %{}, &Map.put(&2, &1, Map.fetch!(data, &1))) end - defp tree_get(data, access) when is_map(access) do + defp tree_get(data, access) when is_map(access) do Enum.reduce(access, %{}, fn {key, value}, acc -> tree = tree_get(data[key], value) Map.put(acc, key, tree) diff --git a/test/axon/compiler_test.exs b/test/axon/compiler_test.exs index 37e0922e..a22a49da 100644 --- a/test/axon/compiler_test.exs +++ b/test/axon/compiler_test.exs @@ -5650,4 +5650,33 @@ defmodule CompilerTest do predict_fn.(params, input) end end + + describe "bidirectional" do + test "works properly with LSTMs" do + input = Axon.input("input") + + model = + input + |> Axon.embedding(10, 16) + |> Axon.bidirectional( + &Axon.lstm(&1, 32, name: "lstm"), + &Nx.concatenate([&1, &2], axis: 1), + name: "bidirectional" + ) + |> Axon.nx(&elem(&1, 0)) + + {init_fn, predict_fn} = Axon.build(model) + + input = Nx.broadcast(1, {1, 10}) + + assert %ModelState{ + data: %{ + "bidirectional" => %{"lstm" => _} + } + } = params = init_fn.(input, ModelState.empty()) + + out = predict_fn.(params, input) + assert Nx.shape(out) == {1, 20, 32} + end + end end diff --git a/test/axon/integration_test.exs b/test/axon/integration_test.exs index f0678f44..65b2a405 100644 --- a/test/axon/integration_test.exs +++ b/test/axon/integration_test.exs @@ -485,7 +485,9 @@ defmodule Axon.IntegrationTest do assert_greater_equal(last_epoch_metrics["validation_accuracy"], 0.60) assert_all_close(final_model_val_accuracy, last_epoch_metrics["validation_accuracy"]) assert Nx.shape(Axon.predict(model, model_state, x_test)) == {10, 2} - assert Nx.type(model_state.data["dense_0"]["kernel"]) == unquote(Macro.escape(policy)).params + + assert Nx.type(model_state.data["dense_0"]["kernel"]) == + unquote(Macro.escape(policy)).params end) end @@ -536,7 +538,9 @@ defmodule Axon.IntegrationTest do assert_greater_equal(last_epoch_metrics["validation_accuracy"], 0.60) assert_all_close(final_model_val_accuracy, last_epoch_metrics["validation_accuracy"]) assert Nx.shape(Axon.predict(model, model_state, x_test)) == {10, 2} - assert Nx.type(model_state.data["dense_0"]["kernel"]) == unquote(Macro.escape(policy)).params + + assert Nx.type(model_state.data["dense_0"]["kernel"]) == + unquote(Macro.escape(policy)).params end) end end