diff --git a/lib/axon.ex b/lib/axon.ex index 4b2c1935..fa73645a 100644 --- a/lib/axon.ex +++ b/lib/axon.ex @@ -3943,9 +3943,10 @@ defmodule Axon do It accepts the same options as `build/2`. """ @doc type: :model - def compile(model, template, init_params \\ %{}, opts \\ []) when is_list(opts) do + def compile(model, template, init_params \\ Axon.ModelState.empty(), opts \\ []) + when is_list(opts) do {init_fn, predict_fn} = build(model, opts) - init_params = Nx.Defn.jit_apply(init_fn, [template, init_params], opts) + init_params = Nx.Defn.jit_apply(init_fn, [template, Axon.ModelState.new(init_params)], opts) predict_compiled_fn = Nx.Defn.compile(predict_fn, [init_params, template], opts) {init_params, predict_compiled_fn} end @@ -3976,7 +3977,7 @@ defmodule Axon do @doc type: :debug def trace_init(model, template, params \\ Axon.ModelState.empty(), opts \\ []) do {init_fn, _} = build(model, opts) - Nx.Defn.jit(init_fn, compiler: Axon.Defn).(template, params) + Nx.Defn.jit(init_fn, compiler: Axon.Defn).(template, Axon.ModelState.new(params)) end @doc """ @@ -4001,7 +4002,7 @@ defmodule Axon do @doc type: :debug def trace_forward(model, inputs, params, opts \\ []) when is_list(opts) do {_, forward_fun} = build(model, opts) - Nx.Defn.jit(forward_fun, compiler: Axon.Defn).(params, inputs) + Nx.Defn.jit(forward_fun, compiler: Axon.Defn).(Axon.ModelState.new(params), inputs) end @doc """ @@ -4034,7 +4035,9 @@ defmodule Axon do end) end - %{prediction: outputs} = Nx.Defn.jit(forward_fn, compiler: Axon.Defn).(params, inputs) + %{prediction: outputs} = + Nx.Defn.jit(forward_fn, compiler: Axon.Defn).(Axon.ModelState.new(params), inputs) + inputs = [params, inputs, outputs] apply(Nx.Defn.jit(backward_fn, compiler: Axon.Defn), inputs) @@ -4042,9 +4045,9 @@ defmodule Axon do @doc false @deprecated "Use Axon.build/2 instead" - def init(model, template, params \\ %{}, opts \\ []) when is_list(opts) do + def init(model, template, params \\ Axon.ModelState.empty(), opts \\ []) when is_list(opts) do {init_fn, _predict_fn} = build(model, opts) - init_fn.(template, params) + init_fn.(template, Axon.ModelState.new(params)) end @doc """ @@ -4069,7 +4072,7 @@ defmodule Axon do @doc type: :model def predict(%Axon{} = model, params, input, opts \\ []) when is_list(opts) do {_init_fn, predict_fn} = build(model, opts) - predict_fn.(params, input) + predict_fn.(Axon.ModelState.new(params), input) end ## Inspection