Skip to content

Commit

Permalink
Proper bidirectional implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmor5 committed May 10, 2024
1 parent 1ccbeba commit 831de55
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 9 deletions.
24 changes: 19 additions & 5 deletions lib/axon.ex
Original file line number Diff line number Diff line change
Expand Up @@ -2357,17 +2357,31 @@ defmodule Axon do
"""
def bidirectional(%Axon{} = input, forward_fun, merge_fun, opts \\ [])
    when is_function(forward_fun, 1) and is_function(merge_fun, 2) do
  # Accept `:name` (no default) alongside `:axis` in a single validation.
  # Validating against `axis: 1` alone would raise as soon as a caller
  # passed `:name`.
  opts = Keyword.validate!(opts, [:name, axis: 1])
  axis = opts[:axis]

  # Wrap the user-supplied layer in one block and apply that same block to
  # both directions, instead of calling `forward_fun` twice on the graph.
  # The layer's output is wrapped in a container so it can be a single
  # value or a nested structure of values.
  fun =
    Axon.block(
      fn x ->
        Axon.container(forward_fun.(x))
      end,
      name: opts[:name]
    )

  forward_out = fun.(input)

  # Backward direction: reverse the input along `axis`, run it through the
  # same block, then reverse every output along `axis` again so that it
  # lines up with the forward outputs before merging.
  backward_out =
    input
    |> Axon.nx(&Nx.reverse(&1, axes: [axis]))
    |> fun.()
    |> Axon.nx(fn x ->
      deep_new(x, &Nx.reverse(&1, axes: [axis]))
    end)

  # Merge inside the graph: `merge_fun` is applied to the paired forward and
  # backward outputs when the model runs, not to the Axon nodes themselves.
  {forward_out, backward_out}
  |> Axon.container()
  |> Axon.nx(fn {forward, backward} ->
    deep_merge(forward, backward, merge_fun)
  end)
end

@doc """
Expand Down
4 changes: 2 additions & 2 deletions lib/axon/model_state.ex
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ defmodule Axon.ModelState do
} = model_state,
updated_parameters,
updated_state \\ %{}
) do
) do
updated_state =
state
|> tree_diff(frozen)
Expand Down Expand Up @@ -215,7 +215,7 @@ defmodule Axon.ModelState do
Enum.reduce(access, %{}, &Map.put(&2, &1, Map.fetch!(data, &1)))
end

defp tree_get(data, access) when is_map(access) do
defp tree_get(data, access) when is_map(access) do
Enum.reduce(access, %{}, fn {key, value}, acc ->
tree = tree_get(data[key], value)
Map.put(acc, key, tree)
Expand Down
29 changes: 29 additions & 0 deletions test/axon/compiler_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -5650,4 +5650,33 @@ defmodule CompilerTest do
predict_fn.(params, input)
end
end

describe "bidirectional" do
  # Builds embedding -> bidirectional LSTM, keeps the first element of the
  # merged output tuple, and checks both the parameter layout and the
  # output shape.
  test "works properly with LSTMs" do
    lstm_fun = &Axon.lstm(&1, 32, name: "lstm")
    merge_fun = &Nx.concatenate([&1, &2], axis: 1)

    embedded = Axon.embedding(Axon.input("input"), 10, 16)

    bidirectional =
      Axon.bidirectional(embedded, lstm_fun, merge_fun, name: "bidirectional")

    model = Axon.nx(bidirectional, &elem(&1, 0))

    {init_fn, predict_fn} = Axon.build(model)

    tokens = Nx.broadcast(1, {1, 10})
    params = init_fn.(tokens, ModelState.empty())

    # The LSTM parameters must be nested under the bidirectional layer's name.
    assert %ModelState{data: %{"bidirectional" => %{"lstm" => _}}} = params

    # Forward and backward outputs are concatenated along axis 1, so that
    # axis is twice the input's 10 entries.
    out = predict_fn.(params, tokens)
    assert Nx.shape(out) == {1, 20, 32}
  end
end
end
8 changes: 6 additions & 2 deletions test/axon/integration_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,9 @@ defmodule Axon.IntegrationTest do
assert_greater_equal(last_epoch_metrics["validation_accuracy"], 0.60)
assert_all_close(final_model_val_accuracy, last_epoch_metrics["validation_accuracy"])
assert Nx.shape(Axon.predict(model, model_state, x_test)) == {10, 2}
assert Nx.type(model_state.data["dense_0"]["kernel"]) == unquote(Macro.escape(policy)).params

assert Nx.type(model_state.data["dense_0"]["kernel"]) ==
unquote(Macro.escape(policy)).params
end)
end

Expand Down Expand Up @@ -536,7 +538,9 @@ defmodule Axon.IntegrationTest do
assert_greater_equal(last_epoch_metrics["validation_accuracy"], 0.60)
assert_all_close(final_model_val_accuracy, last_epoch_metrics["validation_accuracy"])
assert Nx.shape(Axon.predict(model, model_state, x_test)) == {10, 2}
assert Nx.type(model_state.data["dense_0"]["kernel"]) == unquote(Macro.escape(policy)).params

assert Nx.type(model_state.data["dense_0"]["kernel"]) ==
unquote(Macro.escape(policy)).params
end)
end
end
Expand Down

0 comments on commit 831de55

Please sign in to comment.