Skip to content

Commit

Permalink
Add option to not raise if output is none, resolves #538
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmor5 committed May 10, 2024
1 parent 252da8c commit 19803a0
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 5 deletions.
2 changes: 1 addition & 1 deletion lib/axon.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3405,7 +3405,7 @@ defmodule Axon do
"""
@doc type: :graph
def get_output_shape(%Axon{} = axon, inputs, opts \\ []) do
{init_fn, forward_fn} = build(axon, opts)
{init_fn, forward_fn} = build(axon, opts ++ [raise_on_none: false])

out =
Nx.Defn.jit(
Expand Down
11 changes: 7 additions & 4 deletions lib/axon/compiler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ defmodule Axon.Compiler do
@doc false
def build(%Axon{output: id, nodes: nodes}, opts) do
debug? = Keyword.get(opts, :debug, false)
raise_on_none? = Keyword.get(opts, :raise_on_none, true)
mode = Keyword.get(opts, :mode, :inference)
seed = Keyword.get_lazy(opts, :seed, fn -> :erlang.system_time() end)
global_layer_options = Keyword.get(opts, :global_layer_options, [])
Expand Down Expand Up @@ -105,10 +106,12 @@ defmodule Axon.Compiler do
end

with %Axon.None{} <- result do
raise ArgumentError,
"the compiled model will always result in %Axon.None{}." <>
" This most likely means you specified optional output and " <>
" did not handle the case when it is missing"
if raise_on_none? do
raise ArgumentError,
"the compiled model will always result in %Axon.None{}." <>
" This most likely means you specified optional output and " <>
" did not handle the case when it is missing"
end
end

result
Expand Down
14 changes: 14 additions & 0 deletions test/axon_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -1076,5 +1076,19 @@ defmodule AxonTest do
assert shape = Axon.get_output_shape(model, Nx.template({1, 1}, :f32))
assert shape == {{1, 2}, {1, 2}}
end

test "doesn't raise on none output" do
values = Axon.input("values")
mask = Axon.input("mask", optional: true)

model =
values
|> Axon.dense(10)
|> Axon.multiply(mask)
|> Axon.dense(1)
|> Axon.sigmoid()

assert %Axon.None{} = Axon.get_output_shape(model, %{"values" => Nx.template({1, 1}, :f32)})
end
end
end

0 comments on commit 19803a0

Please sign in to comment.