Skip to content

Commit

Permalink
Use templates as parameters (#588)
Browse files Browse the repository at this point in the history
* Use templates as parameters

* Uncomment deps
  • Loading branch information
seanmor5 authored Jul 24, 2024
1 parent 8cee5a9 commit c4d33e5
Show file tree
Hide file tree
Showing 6 changed files with 297 additions and 223 deletions.
281 changes: 209 additions & 72 deletions lib/axon.ex
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,52 @@ defmodule Axon do
@doc """
Trainable Axon parameter used to create custom layers.
Parameters are specified in usages of `Axon.layer` and will be
automatically initialized and used in subsequent applications of
Axon models.
You must specify a parameter "template" which can be a static template
tensor or a function which takes model input templates and returns a
template. It's most common to use functions because most parameters'
shapes rely on input shape information.
"""
@doc type: :special
def parameter(name, template, opts \\ [])

def parameter(name, %Nx.Tensor{} = template, opts) do
opts = Keyword.validate!(opts, initializer: :glorot_uniform, kind: :parameter)
initializer = validate_initializer!(opts[:initializer])
kind = opts[:kind] || :parameter

template = Nx.to_template(template)

%Axon.Parameter{
name: name,
template: template,
initializer: initializer,
kind: kind,
# Legacy
type: Nx.type(template),
shape: Nx.shape(template)
}
end

def parameter(name, function, opts) when is_function(function) do
opts = Keyword.validate!(opts, initializer: :glorot_uniform, kind: :parameter)
initializer = validate_initializer!(opts[:initializer])
kind = opts[:kind] || :parameter

%Axon.Parameter{
name: name,
template: function,
initializer: initializer,
kind: kind
}
end

@doc """
Trainable Axon parameter used to create custom layers.
Parameters are specified in usages of `Axon.layer` and will
be automatically initialized and used in subsequent applications
of Axon models.
Expand All @@ -421,36 +467,35 @@ defmodule Axon do
@doc type: :special
def param(name, shape, opts \\ [])

def param(name, {:map, [_ | _] = inner_params}, opts) do
maybe_warn_on_param_opts(opts)
def param(name, shape, opts) when is_binary(name) and is_tuple(shape) do
opts = Keyword.validate!(opts, initializer: :glorot_uniform, type: {:f, 32}, kind: :parameter)
{type, opts} = Keyword.pop(opts, :type, {:f, 32})

%Axon.Parameter{
name: name,
type: :map,
children: inner_params
}
template = Nx.template(shape, type)
parameter(name, template, opts)
end

def param(name, shape, opts) when is_binary(name) and (is_tuple(shape) or is_function(shape)) do
def param(name, shape, opts) when is_binary(name) and is_function(shape) do
opts = Keyword.validate!(opts, initializer: :glorot_uniform, type: {:f, 32}, kind: :parameter)
initializer = validate_initializer!(opts[:initializer])
type = opts[:type] || {:f, 32}
kind = opts[:kind] || :parameter
{type, opts} = Keyword.pop(opts, :type, {:f, 32})

%Axon.Parameter{
name: name,
shape: shape,
type: type,
initializer: initializer,
kind: kind
}
{:arity, arity} = Function.info(shape, :arity)

template =
shape_fun(arity, fn templates ->
shapes = Enum.map(List.wrap(templates), &Nx.shape/1)
out_shape = apply(shape, shapes)
Nx.template(out_shape, type)
end)

parameter(name, template, opts)
end

defp maybe_warn_on_param_opts(opts) do
if :initializer in opts or :type in opts do
Logger.warning(
"Passing options to a composite parameter has no effect. Pass them to inner parameters instead"
)
for i <- 0..128 do
args = Macro.generate_arguments(i, __MODULE__)

defp shape_fun(unquote(i), callback) do
fn unquote_splicing(args) -> callback.(unquote(args)) end
end
end

Expand Down Expand Up @@ -2583,25 +2628,63 @@ defmodule Axon do
activation = opts[:activation]
gate = opts[:gate]
unroll = opts[:unroll]

kernel_initializer = opts[:kernel_initializer]

input_kernel_shape = fn inp, _, _ -> Axon.Shape.rnn_input_kernel(inp, units, :lstm) end
hidden_kernel_shape = fn inp, _, _ -> Axon.Shape.rnn_hidden_kernel(inp, units, :lstm) end
bias_shape = fn inp, _, _ -> Axon.Shape.rnn_bias(inp, units, :lstm) end
input_kernel_template = fn inp, _, _ ->
shape = Axon.Shape.rnn_input_kernel(Nx.shape(inp), units, :lstm)
Nx.template(shape, :f32)
end

wii = param("wii", input_kernel_shape, initializer: kernel_initializer)
wif = param("wif", input_kernel_shape, initializer: kernel_initializer)
wig = param("wig", input_kernel_shape, initializer: kernel_initializer)
wio = param("wio", input_kernel_shape, initializer: kernel_initializer)
hidden_kernel_template = fn inp, _, _ ->
shape = Axon.Shape.rnn_hidden_kernel(Nx.shape(inp), units, :lstm)
Nx.template(shape, :f32)
end

bias_template = fn inp, _, _ ->
shape = Axon.Shape.rnn_bias(Nx.shape(inp), units, :lstm)
Nx.template(shape, :f32)
end

initializer = fn prefix, init ->
fn shape, type, key ->
split_key = Nx.Random.split(key, parts: 4)

init =
if is_atom(init) do
apply(Axon.Initializers, init, [])
else
init
end

whi = param("whi", hidden_kernel_shape, initializer: kernel_initializer)
whf = param("whf", hidden_kernel_shape, initializer: kernel_initializer)
whg = param("whg", hidden_kernel_shape, initializer: kernel_initializer)
who = param("who", hidden_kernel_shape, initializer: kernel_initializer)
fun =
case init do
init when is_function(init, 2) ->
fn _ -> init.(shape, type) end

init when is_function(init, 3) ->
fn key -> init.(shape, type, key) end
end

%{
"#{prefix}i" => fun.(split_key[0]),
"#{prefix}f" => fun.(split_key[1]),
"#{prefix}g" => fun.(split_key[2]),
"#{prefix}o" => fun.(split_key[3])
}
end
end

# Parameters
input_kernel = param("input_kernel", {:map, [wii, wif, wig, wio]})
hidden_kernel = param("hidden_kernel", {:map, [whi, whf, whg, who]})
input_kernel =
parameter("input_kernel", input_kernel_template,
initializer: initializer.("wi", kernel_initializer)
)

hidden_kernel =
parameter("hidden_kernel", hidden_kernel_template,
initializer: initializer.("wh", kernel_initializer)
)

hidden_state_name =
case opts[:name] do
Expand All @@ -2620,12 +2703,7 @@ defmodule Axon do
if opts[:use_bias] do
bias_initializer = opts[:bias_initializer]

bi = param("bi", bias_shape, initializer: bias_initializer)
bf = param("bf", bias_shape, initializer: bias_initializer)
bg = param("bg", bias_shape, initializer: bias_initializer)
bo = param("bo", bias_shape, initializer: bias_initializer)

bias = param("bias", {:map, [bi, bf, bg, bo]})
bias = parameter("bias", bias_template, initializer: initializer.("b", bias_initializer))

{[x, hidden_state, opts[:mask], input_kernel, hidden_kernel, bias], :lstm}
else
Expand Down Expand Up @@ -2790,22 +2868,58 @@ defmodule Axon do
gate = opts[:gate]
unroll = opts[:unroll]

input_kernel_shape = fn inp, _, _ -> Axon.Shape.rnn_input_kernel(inp, units, :gru) end
hidden_kernel_shape = fn inp, _, _ -> Axon.Shape.rnn_hidden_kernel(inp, units, :gru) end
bias_shape = fn inp, _, _ -> Axon.Shape.rnn_bias(inp, units, :gru) end
input_kernel_template = fn inp, _, _ ->
shape = Axon.Shape.rnn_input_kernel(Nx.shape(inp), units, :gru)
Nx.template(shape, :f32)
end

hidden_kernel_template = fn inp, _, _ ->
shape = Axon.Shape.rnn_hidden_kernel(Nx.shape(inp), units, :gru)
Nx.template(shape, :f32)
end

kernel_initializer = opts[:kernel_initializer]
bias_template = fn inp, _, _ ->
shape = Axon.Shape.rnn_bias(Nx.shape(inp), units, :gru)
Nx.template(shape, :f32)
end

wir = param("wir", input_kernel_shape, initializer: kernel_initializer)
wiz = param("wiz", input_kernel_shape, initializer: kernel_initializer)
win = param("win", input_kernel_shape, initializer: kernel_initializer)
initializer = fn prefix, init ->
fn shape, type, key ->
split_key = Nx.Random.split(key, parts: 3)

whr = param("whr", hidden_kernel_shape, initializer: kernel_initializer)
whz = param("whz", hidden_kernel_shape, initializer: kernel_initializer)
whn = param("whn", hidden_kernel_shape, initializer: kernel_initializer)
init =
if is_atom(init) do
apply(Axon.Initializers, init, [])
else
init
end

input_kernel = param("input_kernel", {:map, [wir, wiz, win]})
hidden_kernel = param("hidden_kernel", {:map, [whr, whz, whn]})
fun =
case init do
init when is_function(init, 2) ->
fn _ -> init.(shape, type) end

init when is_function(init, 3) ->
fn key -> init.(shape, type, key) end
end

%{
"#{prefix}r" => fun.(split_key[0]),
"#{prefix}z" => fun.(split_key[1]),
"#{prefix}n" => fun.(split_key[2])
}
end
end

input_kernel =
parameter("input_kernel", input_kernel_template,
initializer: initializer.("wi", opts[:kernel_initializer])
)

hidden_kernel =
parameter("hidden_kernel", hidden_kernel_template,
initializer: initializer.("wh", opts[:kernel_initializer])
)

hidden_state_name =
case opts[:name] do
Expand All @@ -2822,14 +2936,34 @@ defmodule Axon do

inputs =
if opts[:use_bias] do
bias_initializer = opts[:bias_initializer]
bias_initializer = fn shape, type, key ->
split_key = Nx.Random.split(key, parts: 4)

init =
if is_atom(opts[:bias_initializer]) do
apply(Axon.Initializers, opts[:bias_initializer], [])
else
opts[:bias_initializer]
end

br = param("br", bias_shape, initializer: bias_initializer)
bz = param("bz", bias_shape, initializer: bias_initializer)
bin = param("bin", bias_shape, initializer: bias_initializer)
bhn = param("bhn", bias_shape, initializer: bias_initializer)
fun =
case init do
init when is_function(init, 2) ->
fn _ -> init.(shape, type) end

init when is_function(init, 3) ->
fn key -> init.(shape, type, key) end
end

%{
"br" => fun.(split_key[0]),
"bz" => fun.(split_key[1]),
"bin" => fun.(split_key[2]),
"bhn" => fun.(split_key[3])
}
end

bias = param("bias", {:map, [br, bz, bin, bhn]})
bias = parameter("bias", bias_template, initializer: bias_initializer)

[x, hidden_state, opts[:mask], input_kernel, hidden_kernel, bias]
else
Expand Down Expand Up @@ -2983,23 +3117,26 @@ defmodule Axon do
unroll = opts[:unroll]
kernel_initializer = opts[:kernel_initializer]

hidden_kernel_shape = fn _, {inp, _}, _ ->
shape = Tuple.delete_at(inp, 1)
Axon.Shape.conv_kernel(shape, 4 * units, kernel_size, :first, 1)
hidden_kernel_template = fn _, {inp, _}, _ ->
shape = Tuple.delete_at(Nx.shape(inp), 1)
shape = Axon.Shape.conv_kernel(shape, 4 * units, kernel_size, :first, 1)
Nx.template(shape, :f32)
end

input_kernel_shape = fn inp, _, _ ->
shape = Tuple.delete_at(inp, 1)
Axon.Shape.conv_kernel(shape, 4 * units, kernel_size, :first, 1)
input_kernel_template = fn inp, _, _ ->
shape = Tuple.delete_at(Nx.shape(inp), 1)
shape = Axon.Shape.conv_kernel(shape, 4 * units, kernel_size, :first, 1)
Nx.template(shape, :f32)
end

bias_shape = fn inp, _, _ ->
shape = Tuple.delete_at(inp, 1)
Axon.Shape.conv_bias(shape, 4 * units, kernel_size, :first, 1)
bias_template = fn inp, _, _ ->
shape = Tuple.delete_at(Nx.shape(inp), 1)
shape = Axon.Shape.conv_bias(shape, 4 * units, kernel_size, :first, 1)
Nx.template(shape, :f32)
end

wi = param("input_kernel", input_kernel_shape, initializer: kernel_initializer)
wh = param("hidden_kernel", hidden_kernel_shape, initializer: kernel_initializer)
wi = parameter("input_kernel", input_kernel_template, initializer: kernel_initializer)
wh = parameter("hidden_kernel", hidden_kernel_template, initializer: kernel_initializer)

hidden_state_name =
case opts[:name] do
Expand All @@ -3017,7 +3154,7 @@ defmodule Axon do
{inputs, op} =
if opts[:use_bias] do
bias_initializer = opts[:bias_initializer]
b = param("bias", bias_shape, initializer: bias_initializer)
b = parameter("bias", bias_template, initializer: bias_initializer)
{[x, hidden_state, opts[:mask], wi, wh, b], :conv_lstm}
else
{[x, hidden_state, opts[:mask], wi, wh], :conv_lstm}
Expand Down
Loading

0 comments on commit c4d33e5

Please sign in to comment.