Refactor containers (#590)
* Refactor containers

* Fix warning
seanmor5 authored Jul 24, 2024
1 parent 9fce600 commit b93e87f
Showing 4 changed files with 56 additions and 226 deletions.
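For context: Axon.container/2 groups several Axon graphs into a single node whose output mirrors the shape of the given Nx.Container (a map, tuple, etc.). A minimal usage sketch (layer sizes and names here are illustrative):

    input = Axon.input("features")

    model =
      Axon.container(%{
        mean: Axon.dense(input, 8),
        log_var: Axon.dense(input, 8)
      })

The refactor below changes how such containers are represented internally: rather than a special :container op wrapping the whole container, each container is split into its Axon leaves plus a generated function that rebuilds the structure.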
181 changes: 49 additions & 132 deletions lib/axon.ex
@@ -279,6 +279,8 @@ defmodule Axon do
   alias __MODULE__, as: Axon
   alias Axon.Parameter
 
+  import Axon.Shared
+
   require Logger
 
   @type t :: %__MODULE__{}
@@ -380,15 +382,6 @@ defmodule Axon do
     }
   end
 
-  defp split_inputs(:container, [inputs]) do
-    {inputs, cache} =
-      deep_map_reduce(inputs, %{}, fn %Axon{output: id, nodes: nodes}, cache ->
-        {id, Map.merge(nodes, cache)}
-      end)
-
-    {[inputs], [], [:layer], cache}
-  end
-
   defp split_inputs(_op, inputs) do
     Enum.reduce(inputs, {[], [], [], %{}}, fn
       %Axon{output: layer_input, nodes: nodes}, {layers, params, args, cache} ->
@@ -704,62 +697,47 @@ defmodule Axon do
   @doc type: :special
   def container(container, opts \\ []) do
     opts = Keyword.validate!(opts, [:name, :meta])
 
-    layer(:container, [container], name: opts[:name], meta: opts[:meta], op_name: :container)
+    {structure_fn, nodes} = destructure(container)
+    layer(structure_fn, nodes, name: opts[:name], meta: opts[:meta], op_name: :container)
   end
 
-  # TODO: This should not be duplicated
-  defp deep_new(%Nx.Tensor{} = x, fun), do: fun.(x)
-
-  defp deep_new(x, fun) when is_number(x), do: fun.(x)
-
-  defp deep_new(map, fun) do
-    {cont, :ok} = Nx.Container.traverse(map, :ok, &recur_traverse(&1, &2, fun))
-    cont
+  defp destructure(container) do
+    {structure, {nodes, _}} = recur_destructure(container, {[], 0})
+    fun = restructure(length(nodes) + 1, structure)
+    {fun, Enum.reverse(nodes)}
   end
 
-  defp recur_traverse(item, :ok, fun) do
-    case item do
-      %Axon{} = t ->
-        {fun.(t), :ok}
-
-      %{axon: :axon} = t ->
-        {fun.(t), :ok}
+  defp recur_destructure(container, acc) do
+    Nx.Container.traverse(container, acc, fn value, {leaves, idx} ->
+      case value do
+        %Axon{} = leaf ->
+          {idx, {[leaf | leaves], idx + 1}}
 
-      container ->
-        {deep_new(container, fun), :ok}
-    end
+        container ->
+          recur_destructure(container, {leaves, idx})
+      end
+    end)
   end
 
-  defp deep_merge(left, right, fun) do
-    case Nx.Container.traverse(left, leaves(right), &recur_merge(&1, &2, fun)) do
-      {merged, []} ->
-        merged
+  for i <- 0..128 do
+    args = Macro.generate_arguments(i, __MODULE__)
 
-      {_merged, _leftover} ->
-        raise ArgumentError,
-              "unable to merge arguments with incompatible" <>
-                " structure"
+    defp restructure(unquote(i), structure) do
+      fn unquote_splicing(args) ->
+        args_tuple = {unquote_splicing(args)}
+        {container, :ok} = recur_restructure(structure, args_tuple)
+        container
+      end
     end
   end
-
-  defp leaves(container) do
-    container
-    |> Nx.Container.reduce([], fn x, acc -> [x | acc] end)
-    |> Enum.reverse()
-  end
-
-  defp recur_merge(left, [right | right_leaves], fun) do
-    case {left, right} do
-      {%Nx.Tensor{} = left, %Nx.Tensor{} = right} ->
-        {fun.(left, right), right_leaves}
-
-      {%Axon{} = left, %Axon{} = right} ->
-        {fun.(left, right), right_leaves}
-
-      {left, right} ->
-        {deep_merge(left, right, fun), right_leaves}
-    end
+  defp recur_restructure(structure, args_tuple) do
+    Nx.Container.traverse(structure, :ok, fn value, :ok ->
+      case value do
+        idx when is_integer(idx) -> {elem(args_tuple, idx), :ok}
+        container -> recur_restructure(container, args_tuple)
+      end
+    end)
   end
 
   @doc """
@@ -3644,35 +3622,31 @@ defmodule Axon do
   end
 
   @doc """
-  Returns a model's output shape from the given input
+  Returns a model's output template from the given input
   template.
+  The output template gives you access to the output shape
+  and type of the given input graph.
   """
   @doc type: :graph
   def get_output_shape(%Axon{} = axon, inputs, opts \\ []) do
     {init_fn, forward_fn} = build(axon, opts ++ [raise_on_none: false])
 
-    out =
+    inputs =
+      case inputs do
+        %Nx.Tensor{} = input -> Nx.to_template(input)
+        inputs when is_map(inputs) -> Map.new(inputs, fn {k, v} -> {k, Nx.to_template(v)} end)
+      end
+
+    fun =
       Nx.Defn.jit(
         fn inputs ->
          forward_fn.(init_fn.(inputs, Axon.ModelState.empty()), inputs)
        end,
        compiler: Axon.Defn
-      ).(inputs)
-
-    safe_shape(out)
-  end
-
-  defp safe_shape(container_or_tensor) do
-    case container_or_tensor do
-      %Axon.None{} = none ->
-        none
-
-      %Nx.Tensor{} = tensor ->
-        Nx.shape(tensor)
+      )
 
-      container ->
-        deep_new(container, &Nx.shape/1)
-    end
+    deep_new(apply(fun, [inputs]), &Nx.to_template/1)
   end
 
   @doc """
@@ -3850,74 +3824,17 @@ defmodule Axon do
     if MapSet.member?(visited, id) do
       {acc, visited}
     else
-      %{op: op, parent: parents} = parent = nodes[id]
+      %{parent: parents} = parent = nodes[id]
 
       {acc, visited} =
-        case op do
-          :container ->
-            [container] = parents
-
-            deep_reduce(container, {acc, visited}, fn pid, {acc, visited} ->
-              traverse_nodes(pid, nodes, acc, visited)
-            end)
-
-          _ ->
-            Enum.reduce(parents, {acc, visited}, fn pid, {acc, visited} ->
-              traverse_nodes(pid, nodes, acc, visited)
-            end)
-        end
+        Enum.reduce(parents, {acc, visited}, fn pid, {acc, visited} ->
+          traverse_nodes(pid, nodes, acc, visited)
+        end)
 
       {[parent | acc], MapSet.put(visited, id)}
     end
   end
 
-  # TODO: Do not duplicate
-  defp deep_reduce(item, acc, fun) when is_integer(item) do
-    fun.(item, acc)
-  end
-
-  defp deep_reduce(map, acc, fun) do
-    Nx.Container.reduce(map, acc, &recur_deep_reduce(&1, &2, fun))
-  end
-
-  defp recur_deep_reduce(value, acc, fun) do
-    case value do
-      %Axon{} = val ->
-        fun.(val, acc)
-
-      %Nx.Tensor{} = val ->
-        fun.(val, acc)
-
-      %{axon: :axon} = val ->
-        fun.(val, acc)
-
-      val when is_integer(val) ->
-        fun.(val, acc)
-
-      val ->
-        deep_reduce(val, acc, fun)
-    end
-  end
-
-  defp deep_map_reduce(leaf, acc, fun) when is_integer(leaf), do: fun.(leaf, acc)
-
-  defp deep_map_reduce(container, acc, fun) do
-    Nx.Container.traverse(container, acc, &recur_deep_map_reduce(&1, &2, fun))
-  end
-
-  defp recur_deep_map_reduce(leaf, acc, fun) do
-    case leaf do
-      %Axon{} = leaf ->
-        fun.(leaf, acc)
-
-      %Nx.Tensor{} = leaf ->
-        fun.(leaf, acc)
-
-      container ->
-        deep_map_reduce(container, acc, fun)
-    end
-  end
-
   @doc """
   Pops the top node off of the graph.
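The heart of the refactor is the destructure/restructure pair above: a container of Axon nodes is flattened into a list of leaves plus a structure in which each leaf is replaced by its index, and a function generated with Macro.generate_arguments/2 rebuilds the container from positional arguments (one generated clause per arity from 0 to 128; the + 1 in destructure/1 presumably leaves room for the trailing options argument layer ops receive). A standalone sketch of the idea, using plain maps rather than the Nx.Container protocol:

    defmodule ContainerSketch do
      # Flatten a nested map into {structure, leaves}: every non-map
      # leaf is replaced by its index into the returned leaf list.
      def destructure(container) do
        {structure, {leaves, _idx}} = recur(container, {[], 0})
        {structure, Enum.reverse(leaves)}
      end

      defp recur(%{} = map, acc0) do
        {pairs, acc1} =
          Enum.map_reduce(map, acc0, fn {k, v}, acc ->
            {v2, acc2} = recur(v, acc)
            {{k, v2}, acc2}
          end)

        {Map.new(pairs), acc1}
      end

      defp recur(leaf, {leaves, idx}), do: {idx, {[leaf | leaves], idx + 1}}

      # Rebuild the original shape by substituting each index with the
      # corresponding element of a tuple of positional values.
      def restructure(%{} = structure, values) do
        Map.new(structure, fn {k, v} -> {k, restructure(v, values)} end)
      end

      def restructure(idx, values) when is_integer(idx), do: elem(values, idx)
    end

    ContainerSketch.destructure(%{a: :x, b: %{c: :y}})
    #=> {%{a: 0, b: %{c: 1}}, [:x, :y]}

    ContainerSketch.restructure(%{a: 0, b: %{c: 1}}, {:x, :y})
    #=> %{a: :x, b: %{c: :y}}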
67 changes: 0 additions & 67 deletions lib/axon/compiler.ex
@@ -40,7 +40,6 @@ defmodule Axon.Compiler do
   @moduledoc false
   require Logger
 
-  import Axon.Shared
   alias Axon.StatefulOutput
 
   ## Init JIT Compilation
@@ -549,72 +548,6 @@ defmodule Axon.Compiler do
     {id, model_funs, cache, op_counts, block_cache, model_state_meta}
   end
 
-  defp recur_model_funs(
-         %Axon.Node{id: id, op: :container, parent: [parents]},
-         nodes,
-         cache_and_counts,
-         config
-       ) do
-    {parent_ids, {cache, op_counts, block_cache, model_state_meta}} =
-      deep_map_reduce(parents, cache_and_counts, &to_model_funs(&1, nodes, &2, config))
-
-    op_counts = Map.update(op_counts, :container, 1, fn x -> x + 1 end)
-
-    predict_fun = fn params, inputs, state, cache, result_cache, fn_stacktrace ->
-      {input, {state, result_cache, none?}} =
-        deep_map_reduce(
-          parent_ids,
-          {state, result_cache, false},
-          fn parent_id, {state, result_cache, none?} ->
-            {input, {state, result_cache}} =
-              call_predict_cache(
-                parent_id,
-                params,
-                inputs,
-                state,
-                cache,
-                result_cache,
-                fn_stacktrace
-              )
-
-            none? = none? or propagating_none?(input)
-            {input, {state, result_cache, none?}}
-          end
-        )
-
-      input = if none?, do: %Axon.None{}, else: input
-
-      {input, {state, result_cache}}
-    end
-
-    init_fun = fn template, cache, result_cache, fn_stacktrace, keys ->
-      {parent_template, {parent_params, result_cache, none?}} =
-        deep_map_reduce(parent_ids, {%{}, result_cache, false}, fn
-          parent_id, {params, result_cache, none?} ->
-            {parent_template, {params, result_cache}} =
-              call_init_cache(
-                parent_id,
-                template,
-                params,
-                cache,
-                result_cache,
-                fn_stacktrace,
-                keys
-              )
-
-            none? = none? or propagating_none?(parent_template)
-            {parent_template, {params, result_cache, none?}}
-        end)
-
-      parent_template = if none?, do: %Axon.None{}, else: parent_template
-
-      {parent_template, {parent_params, result_cache}}
-    end
-
-    model_funs = %{predict: predict_fun, init: init_fun}
-    {id, model_funs, cache, op_counts, block_cache, model_state_meta}
-  end
-
   defp recur_model_funs(
          %Axon.Node{
            id: id,
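With containers represented this way, the compiler's :container special case above becomes dead code: a container node is now an ordinary layer whose op is the generated restructure function, so the generic layer path evaluates it like any other op. A hypothetical illustration (not the actual compiler code):

    # The op simply rebuilds the container from its already-computed inputs:
    op = fn left, right -> %{left: left, right: right} end
    apply(op, [1.0, 2.0])
    #=> %{left: 1.0, right: 2.0}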
29 changes: 4 additions & 25 deletions lib/axon/display.ex
@@ -3,7 +3,6 @@ defmodule Axon.Display do
   Module for rendering various visual representations of Axon models.
   """
 
-  import Axon.Shared
   alias Axon.Parameter
 
   @compile {:no_warn_undefined, TableRex.Table}
@@ -94,7 +93,8 @@ defmodule Axon.Display do
   defp do_axon_to_rows(
          %Axon.Node{
            id: id,
-           op: :container,
+           op: structure,
+           op_name: :container,
            parent: [parents],
            name: name_fn
          },
@@ -105,7 +105,7 @@
         model_info
       ) do
     {input_names, {cache, op_counts, model_info}} =
-      deep_map_reduce(parents, {cache, op_counts, model_info}, fn
+      Enum.map_reduce(parents, {cache, op_counts, model_info}, fn
        parent_id, {cache, op_counts, model_info} ->
          {_, name, _shape, cache, op_counts, model_info} =
            axon_to_rows(parent_id, nodes, templates, cache, op_counts, model_info)
@@ -119,7 +119,7 @@
     shape = Axon.get_output_shape(%Axon{output: id, nodes: nodes}, templates)
 
     row = [
-      "#{name} ( #{op_string} #{inspect(input_names)} )",
+      "#{name} ( #{op_string} #{inspect(apply(structure, input_names))} )",
       "#{inspect({})}",
       "#{inspect(shape)}",
       render_options([]),
@@ -311,27 +311,6 @@
     end
   end
 
-  defp recur_axon_to_edges(
-         %Axon.Node{id: id, op: :container, name: name_fn, parent: [parents]},
-         nodes,
-         templates,
-         cache_counts_edgelist
-       ) do
-    {node_inputs, {cache, op_counts, edgelist}} =
-      deep_map_reduce(parents, cache_counts_edgelist, &axon_to_edges(&1, nodes, templates, &2))
-
-    name = name_fn.(:container, op_counts)
-    node_shape = Axon.get_output_shape(%Axon{output: id, nodes: nodes}, templates)
-    to_node = %{axon: :axon, id: id, op: :container, name: name, shape: node_shape}
-
-    new_edgelist =
-      deep_reduce(node_inputs, edgelist, fn from_node, acc ->
-        [{from_node, to_node} | acc]
-      end)
-
-    {to_node, {cache, op_counts, new_edgelist}}
-  end
-
   defp recur_axon_to_edges(
          %Axon.Node{id: id, op_name: op, name: name_fn, parent: parents},
          nodes,
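Likewise, the display code above recovers the container for printing by applying the structure function to the rendered input names, as in inspect(apply(structure, input_names)). Roughly, as a simplified illustration that elides any trailing options argument:

    structure = fn a, b -> %{mean: a, log_var: b} end
    apply(structure, ["dense_0", "dense_1"])
    #=> %{log_var: "dense_1", mean: "dense_0"}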
