Skip to content

Commit

Permalink
Add metadata to layers (#535)
Browse files Browse the repository at this point in the history
* Add metadata to layers

* Apply suggestions from code review

* Update lib/axon/compiler.ex

---------

Co-authored-by: José Valim <[email protected]>
  • Loading branch information
seanmor5 and josevalim authored Oct 10, 2023
1 parent 57c75e5 commit 235a030
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 7 deletions.
32 changes: 25 additions & 7 deletions lib/axon/compiler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,7 @@ defmodule Axon.Compiler do
&5,
&6,
op,
op_name,
parent_ids,
name,
args,
Expand Down Expand Up @@ -788,6 +789,7 @@ defmodule Axon.Compiler do
result_cache,
fn_stacktrace,
op,
op_name,
parent_ids,
name,
args,
Expand Down Expand Up @@ -874,7 +876,7 @@ defmodule Axon.Compiler do
# in Axon.Layers. The implication of this is that every function which
# can be invoked as a layer must have a definition in Axon.Layers even
# if there is a distinction (e.g. with activations)
result = apply_layer(name, op, args, layer_stacktrace, fn_stacktrace)
result = apply_layer(name, op, args, layer_stacktrace, fn_stacktrace, op_name)

result =
case result do
Expand Down Expand Up @@ -912,14 +914,30 @@ defmodule Axon.Compiler do
end
end

defp apply_layer(name, op, args, layer_stacktrace, fn_stacktrace) do
defp apply_layer(name, op, args, layer_stacktrace, fn_stacktrace, op_name) do
try do
case op do
op when is_function(op) ->
apply(op, args)
result =
case op do
op when is_function(op) ->
apply(op, args)

op when is_atom(op) ->
apply(Axon.Layers, op, args)
end

case result do
out when is_tuple(out) ->
out

%Axon.None{} = out ->
out

%Axon.StatefulOutput{output: out} = stateful ->
out = Nx.Defn.Expr.metadata(Nx.Defn.Expr.tensor(out), %{axon_layer: op_name})
%{stateful | output: out}

op when is_atom(op) ->
apply(Axon.Layers, op, args)
out ->
Nx.Defn.Expr.metadata(Nx.Defn.Expr.tensor(out), %{axon_layer: op_name})
end
rescue
exception ->
Expand Down
15 changes: 15 additions & 0 deletions test/axon/compiler_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -5758,4 +5758,19 @@ defmodule CompilerTest do
assert predict_fn1 == predict_fn2
end
end

describe "metadata" do
test "axon compiler attaches layer name as metadata to subgraphs" do
model = Axon.input("input", shape: {nil, 784}) |> Axon.dense(128)

{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 784}, :f32), %{})
input = Nx.broadcast(0.0, {1, 784})

expr_fn = Nx.Defn.jit(predict_fn, compiler: Axon.Defn)
expr = expr_fn.(params, input)

assert %{data: %{op: :metadata, args: [_tensor, %{axon_layer: :dense}]}} = expr
end
end
end

0 comments on commit 235a030

Please sign in to comment.