diff --git a/lib/axon.ex b/lib/axon.ex index e7bd58fc..40e44a43 100644 --- a/lib/axon.ex +++ b/lib/axon.ex @@ -279,6 +279,8 @@ defmodule Axon do alias __MODULE__, as: Axon alias Axon.Parameter + import Axon.Shared + require Logger @type t :: %__MODULE__{} @@ -380,15 +382,6 @@ defmodule Axon do } end - defp split_inputs(:container, [inputs]) do - {inputs, cache} = - deep_map_reduce(inputs, %{}, fn %Axon{output: id, nodes: nodes}, cache -> - {id, Map.merge(nodes, cache)} - end) - - {[inputs], [], [:layer], cache} - end - defp split_inputs(_op, inputs) do Enum.reduce(inputs, {[], [], [], %{}}, fn %Axon{output: layer_input, nodes: nodes}, {layers, params, args, cache} -> @@ -704,62 +697,47 @@ defmodule Axon do @doc type: :special def container(container, opts \\ []) do opts = Keyword.validate!(opts, [:name, :meta]) - - layer(:container, [container], name: opts[:name], meta: opts[:meta], op_name: :container) + {structure_fn, nodes} = destructure(container) + layer(structure_fn, nodes, name: opts[:name], meta: opts[:meta], op_name: :container) end - # TODO: This should not be duplicated - defp deep_new(%Nx.Tensor{} = x, fun), do: fun.(x) - - defp deep_new(x, fun) when is_number(x), do: fun.(x) - - defp deep_new(map, fun) do - {cont, :ok} = Nx.Container.traverse(map, :ok, &recur_traverse(&1, &2, fun)) - cont + defp destructure(container) do + {structure, {nodes, _}} = recur_destructure(container, {[], 0}) + fun = restructure(length(nodes) + 1, structure) + {fun, Enum.reverse(nodes)} end - defp recur_traverse(item, :ok, fun) do - case item do - %Axon{} = t -> - {fun.(t), :ok} - - %{axon: :axon} = t -> - {fun.(t), :ok} + defp recur_destructure(container, acc) do + Nx.Container.traverse(container, acc, fn value, {leaves, idx} -> + case value do + %Axon{} = leaf -> + {idx, {[leaf | leaves], idx + 1}} - container -> - {deep_new(container, fun), :ok} - end + container -> + recur_destructure(container, {leaves, idx}) + end + end) end - defp deep_merge(left, right, fun) do - case Nx.Container.traverse(left, leaves(right), &recur_merge(&1, &2, fun)) do - {merged, []} -> - merged + for i <- 0..128 do + args = Macro.generate_arguments(i, __MODULE__) - {_merged, _leftover} -> - raise ArgumentError, - "unable to merge arguments with incompatible" <> - " structure" + defp restructure(unquote(i), structure) do + fn unquote_splicing(args) -> + args_tuple = {unquote_splicing(args)} + {container, :ok} = recur_restructure(structure, args_tuple) + container + end end end - defp leaves(container) do - container - |> Nx.Container.reduce([], fn x, acc -> [x | acc] end) - |> Enum.reverse() - end - - defp recur_merge(left, [right | right_leaves], fun) do - case {left, right} do - {%Nx.Tensor{} = left, %Nx.Tensor{} = right} -> - {fun.(left, right), right_leaves} - - {%Axon{} = left, %Axon{} = right} -> - {fun.(left, right), right_leaves} - - {left, right} -> - {deep_merge(left, right, fun), right_leaves} - end + defp recur_restructure(structure, args_tuple) do + Nx.Container.traverse(structure, :ok, fn value, :ok -> + case value do + idx when is_integer(idx) -> {elem(args_tuple, idx), :ok} + container -> recur_restructure(container, args_tuple) + end + end) end @doc """ @@ -3644,35 +3622,31 @@ defmodule Axon do end @doc """ - Returns a model's output shape from the given input + Returns a model's output template from the given input template. + + The output template gives you access to the output shape + and type of the given input graph. """ @doc type: :graph def get_output_shape(%Axon{} = axon, inputs, opts \\ []) do {init_fn, forward_fn} = build(axon, opts ++ [raise_on_none: false]) - out = + inputs = + case inputs do + %Nx.Tensor{} = input -> Nx.to_template(input) + inputs when is_map(inputs) -> Map.new(inputs, fn {k, v} -> {k, Nx.to_template(v)} end) + end + + fun = Nx.Defn.jit( fn inputs -> forward_fn.(init_fn.(inputs, Axon.ModelState.empty()), inputs) end, compiler: Axon.Defn - ).(inputs) - - safe_shape(out) - end - - defp safe_shape(container_or_tensor) do - case container_or_tensor do - %Axon.None{} = none -> - none - - %Nx.Tensor{} = tensor -> - Nx.shape(tensor) + ) - container -> - deep_new(container, &Nx.shape/1) - end + deep_new(apply(fun, [inputs]), &Nx.to_template/1) end @doc """ @@ -3783,74 +3757,17 @@ defmodule Axon do if MapSet.member?(visited, id) do {acc, visited} else - %{op: op, parent: parents} = parent = nodes[id] + %{parent: parents} = parent = nodes[id] {acc, visited} = - case op do - :container -> - [container] = parents - - deep_reduce(container, {acc, visited}, fn pid, {acc, visited} -> - traverse_nodes(pid, nodes, acc, visited) - end) - - _ -> - Enum.reduce(parents, {acc, visited}, fn pid, {acc, visited} -> - traverse_nodes(pid, nodes, acc, visited) - end) - end + Enum.reduce(parents, {acc, visited}, fn pid, {acc, visited} -> + traverse_nodes(pid, nodes, acc, visited) + end) {[parent | acc], MapSet.put(visited, id)} end end - # TODO: Do not duplicate - defp deep_reduce(item, acc, fun) when is_integer(item) do - fun.(item, acc) - end - - defp deep_reduce(map, acc, fun) do - Nx.Container.reduce(map, acc, &recur_deep_reduce(&1, &2, fun)) - end - - defp recur_deep_reduce(value, acc, fun) do - case value do - %Axon{} = val -> - fun.(val, acc) - - %Nx.Tensor{} = val -> - fun.(val, acc) - - %{axon: :axon} = val -> - fun.(val, acc) - - val when is_integer(val) -> - fun.(val, acc) - - val -> - deep_reduce(val, acc, fun) - end - end - - defp deep_map_reduce(leaf, acc, fun) when is_integer(leaf), do: fun.(leaf, acc) - - defp deep_map_reduce(container, acc, fun) do - Nx.Container.traverse(container, acc, &recur_deep_map_reduce(&1, &2, fun)) - end - - defp recur_deep_map_reduce(leaf, acc, fun) do - case leaf do - %Axon{} = leaf -> - fun.(leaf, acc) - - %Nx.Tensor{} = leaf -> - fun.(leaf, acc) - - container -> - deep_map_reduce(container, acc, fun) - end - end - @doc """ Pops the top node off of the graph. diff --git a/lib/axon/compiler.ex b/lib/axon/compiler.ex index 4cf671ae..ae449a99 100644 --- a/lib/axon/compiler.ex +++ b/lib/axon/compiler.ex @@ -40,7 +40,6 @@ defmodule Axon.Compiler do @moduledoc false require Logger - import Axon.Shared alias Axon.StatefulOutput ## Init JIT Compilation @@ -549,72 +548,6 @@ defmodule Axon.Compiler do {id, model_funs, cache, op_counts, block_cache, model_state_meta} end - defp recur_model_funs( - %Axon.Node{id: id, op: :container, parent: [parents]}, - nodes, - cache_and_counts, - config - ) do - {parent_ids, {cache, op_counts, block_cache, model_state_meta}} = - deep_map_reduce(parents, cache_and_counts, &to_model_funs(&1, nodes, &2, config)) - - op_counts = Map.update(op_counts, :container, 1, fn x -> x + 1 end) - - predict_fun = fn params, inputs, state, cache, result_cache, fn_stacktrace -> - {input, {state, result_cache, none?}} = - deep_map_reduce( - parent_ids, - {state, result_cache, false}, - fn parent_id, {state, result_cache, none?} -> - {input, {state, result_cache}} = - call_predict_cache( - parent_id, - params, - inputs, - state, - cache, - result_cache, - fn_stacktrace - ) - - none? = none? or propagating_none?(input) - {input, {state, result_cache, none?}} - end - ) - - input = if none?, do: %Axon.None{}, else: input - - {input, {state, result_cache}} - end - - init_fun = fn template, cache, result_cache, fn_stacktrace, keys -> - {parent_template, {parent_params, result_cache, none?}} = - deep_map_reduce(parent_ids, {%{}, result_cache, false}, fn - parent_id, {params, result_cache, none?} -> - {parent_template, {params, result_cache}} = - call_init_cache( - parent_id, - template, - params, - cache, - result_cache, - fn_stacktrace, - keys - ) - - none? = none? or propagating_none?(parent_template) - {parent_template, {params, result_cache, none?}} - end) - - parent_template = if none?, do: %Axon.None{}, else: parent_template - - {parent_template, {parent_params, result_cache}} - end - - model_funs = %{predict: predict_fun, init: init_fun} - {id, model_funs, cache, op_counts, block_cache, model_state_meta} - end - defp recur_model_funs( %Axon.Node{ id: id, diff --git a/lib/axon/display.ex b/lib/axon/display.ex index 1e95e9c6..65241ea9 100644 --- a/lib/axon/display.ex +++ b/lib/axon/display.ex @@ -94,7 +94,8 @@ defmodule Axon.Display do defp do_axon_to_rows( %Axon.Node{ id: id, - op: :container, + op: structure, + op_name: :container, parent: [parents], name: name_fn }, @@ -105,7 +106,7 @@ defmodule Axon.Display do model_info ) do {input_names, {cache, op_counts, model_info}} = - deep_map_reduce(parents, {cache, op_counts, model_info}, fn + Enum.map_reduce(parents, {cache, op_counts, model_info}, fn parent_id, {cache, op_counts, model_info} -> {_, name, _shape, cache, op_counts, model_info} = axon_to_rows(parent_id, nodes, templates, cache, op_counts, model_info) @@ -119,7 +120,7 @@ defmodule Axon.Display do shape = Axon.get_output_shape(%Axon{output: id, nodes: nodes}, templates) row = [ - "#{name} ( #{op_string} #{inspect(input_names)} )", + "#{name} ( #{op_string} #{inspect(apply(structure, input_names))} )", "#{inspect({})}", "#{inspect(shape)}", render_options([]), @@ -311,27 +312,6 @@ defmodule Axon.Display do end end - defp recur_axon_to_edges( - %Axon.Node{id: id, op: :container, name: name_fn, parent: [parents]}, - nodes, - templates, - cache_counts_edgelist - ) do - {node_inputs, {cache, op_counts, edgelist}} = - deep_map_reduce(parents, cache_counts_edgelist, &axon_to_edges(&1, nodes, templates, &2)) - - name = name_fn.(:container, op_counts) - node_shape = Axon.get_output_shape(%Axon{output: id, nodes: nodes}, templates) - to_node = %{axon: :axon, id: id, op: :container, name: name, shape: node_shape} - - new_edgelist = - deep_reduce(node_inputs, edgelist, fn from_node, acc -> - [{from_node, to_node} | acc] - end) - - {to_node, {cache, op_counts, new_edgelist}} - end - defp recur_axon_to_edges( %Axon.Node{id: id, op_name: op, name: name_fn, parent: parents}, nodes, diff --git a/test/axon_test.exs b/test/axon_test.exs index c77f14f2..87c4878e 100644 --- a/test/axon_test.exs +++ b/test/axon_test.exs @@ -874,8 +874,9 @@ defmodule AxonTest do out = Axon.input("input") |> Axon.dense(2) model = Axon.container({out, out}) - assert shape = Axon.get_output_shape(model, Nx.template({1, 1}, :f32)) - assert shape == {{1, 2}, {1, 2}} + assert {t1, t2} = Axon.get_output_shape(model, Nx.template({1, 1}, :f32)) + assert Nx.shape(t1) == {1, 2} + assert Nx.shape(t2) == {1, 2} end test "doesn't raise on none output" do