Skip to content

Commit

Permalink
Use model state everywhere as default
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmor5 committed Jul 30, 2024
1 parent 8e0a6d9 commit 054eb4c
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions lib/axon.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3943,9 +3943,10 @@ defmodule Axon do
It accepts the same options as `build/2`.
"""
@doc type: :model
def compile(model, template, init_params \\ %{}, opts \\ []) when is_list(opts) do
def compile(model, template, init_params \\ Axon.ModelState.empty(), opts \\ [])
when is_list(opts) do
{init_fn, predict_fn} = build(model, opts)
init_params = Nx.Defn.jit_apply(init_fn, [template, init_params], opts)
init_params = Nx.Defn.jit_apply(init_fn, [template, Axon.ModelState.new(init_params)], opts)
predict_compiled_fn = Nx.Defn.compile(predict_fn, [init_params, template], opts)
{init_params, predict_compiled_fn}
end
Expand Down Expand Up @@ -3976,7 +3977,7 @@ defmodule Axon do
@doc type: :debug
def trace_init(model, template, params \\ Axon.ModelState.empty(), opts \\ []) do
{init_fn, _} = build(model, opts)
Nx.Defn.jit(init_fn, compiler: Axon.Defn).(template, params)
Nx.Defn.jit(init_fn, compiler: Axon.Defn).(template, Axon.ModelState.new(params))
end

@doc """
Expand All @@ -4001,7 +4002,7 @@ defmodule Axon do
@doc type: :debug
def trace_forward(model, inputs, params, opts \\ []) when is_list(opts) do
{_, forward_fun} = build(model, opts)
Nx.Defn.jit(forward_fun, compiler: Axon.Defn).(params, inputs)
Nx.Defn.jit(forward_fun, compiler: Axon.Defn).(Axon.ModelState.new(params), inputs)
end

@doc """
Expand Down Expand Up @@ -4034,17 +4035,19 @@ defmodule Axon do
end)
end

%{prediction: outputs} = Nx.Defn.jit(forward_fn, compiler: Axon.Defn).(params, inputs)
%{prediction: outputs} =
Nx.Defn.jit(forward_fn, compiler: Axon.Defn).(Axon.ModelState.new(params), inputs)

inputs = [params, inputs, outputs]

apply(Nx.Defn.jit(backward_fn, compiler: Axon.Defn), inputs)
end

@doc false
@deprecated "Use Axon.build/2 instead"
def init(model, template, params \\ %{}, opts \\ []) when is_list(opts) do
def init(model, template, params \\ Axon.ModelState.empty(), opts \\ []) when is_list(opts) do
{init_fn, _predict_fn} = build(model, opts)
init_fn.(template, params)
init_fn.(template, Axon.ModelState.new(params))
end

@doc """
Expand All @@ -4069,7 +4072,7 @@ defmodule Axon do
@doc type: :model
def predict(%Axon{} = model, params, input, opts \\ []) when is_list(opts) do
{_init_fn, predict_fn} = build(model, opts)
predict_fn.(params, input)
predict_fn.(Axon.ModelState.new(params), input)
end

## Inspection
Expand Down

0 comments on commit 054eb4c

Please sign in to comment.