Skip to content

Commit

Permalink
Do not cast integers in Axon.MixedPrecision.cast/2 (#562)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko authored Mar 5, 2024
1 parent 2a2d165 commit acdc002
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 25 deletions.
33 changes: 10 additions & 23 deletions lib/axon/compiler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ defmodule Axon.Compiler do
end

defp recur_model_funs(
%Axon.Node{id: id, op: :constant, opts: [value: tensor], policy: %{output: output}},
%Axon.Node{id: id, op: :constant, opts: [value: tensor], policy: policy},
_nodes,
{cache, op_counts, block_cache},
_
Expand All @@ -361,7 +361,7 @@ defmodule Axon.Compiler do
tensor = Nx.backend_copy(tensor, Nx.BinaryBackend)

predict_fun = fn _params, _inputs, state, _cache, result_cache, _fn_stacktrace ->
out = safe_as_type(tensor, output)
out = safe_policy_cast(tensor, policy, :output)
{out, {state, result_cache}}
end

Expand Down Expand Up @@ -841,7 +841,7 @@ defmodule Axon.Compiler do
name,
args,
opts,
%{output: output, compute: compute},
policy,
layer_params,
hooks,
mode,
Expand Down Expand Up @@ -870,7 +870,7 @@ defmodule Axon.Compiler do

layer_input =
layer_input
|> safe_as_type(compute)
|> safe_policy_cast(policy, :compute)
|> apply_hooks(:pre_forward, mode, hooks)

{layer_input, {state, result_cache, none?}}
Expand All @@ -889,7 +889,7 @@ defmodule Axon.Compiler do

cond do
param != nil ->
safe_as_type(maybe_freeze(param, frz), compute)
safe_policy_cast(maybe_freeze(param, frz), policy, :compute)

true ->
raise ArgumentError,
Expand Down Expand Up @@ -939,7 +939,7 @@ defmodule Axon.Compiler do
out
|> apply_hooks(:forward, mode, hooks)
|> apply_hooks(:backward, mode, hooks)
|> safe_as_type(output)
|> safe_policy_cast(policy, :output)

new_state = Map.put(state, name, out_state)
{new_out, new_state}
Expand All @@ -949,7 +949,7 @@ defmodule Axon.Compiler do
out
|> apply_hooks(:forward, mode, hooks)
|> apply_hooks(:backward, mode, hooks)
|> safe_as_type(output)
|> safe_policy_cast(policy, :output)

{new_out, state}
end
Expand Down Expand Up @@ -1130,26 +1130,13 @@ defmodule Axon.Compiler do
end)
end

defp safe_as_type(container_or_tensor, type) do
defp safe_policy_cast(container_or_tensor, policy, variable_type) do
case container_or_tensor do
%Axon.None{} = none ->
none

%Nx.Tensor{} = tensor ->
if not Nx.Type.integer?(Nx.type(tensor)) and not Nx.Type.integer?(type) do
Nx.as_type(tensor, type)
else
tensor
end

container ->
deep_new(container, fn tensor ->
if not Nx.Type.integer?(Nx.type(tensor)) and not Nx.Type.integer?(type) do
Nx.as_type(tensor, type)
else
tensor
end
end)
container_or_tensor ->
Axon.MixedPrecision.cast(policy, container_or_tensor, variable_type)
end
end

Expand Down
20 changes: 18 additions & 2 deletions lib/axon/mixed_precision.ex
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,26 @@ defmodule Axon.MixedPrecision do
iex> value = Axon.MixedPrecision.cast(policy, value, :output)
iex> Nx.type(value)
{:bf, 16}
Note that integers are never promoted to floats:
iex> policy = Axon.MixedPrecision.create_policy(output: {:f, 16})
iex> value = Nx.tensor([1, 2, 3], type: :s64)
iex> value = Axon.MixedPrecision.cast(policy, value, :params)
iex> Nx.type(value)
{:s, 64}
"""
def cast(%Policy{} = policy, tensor_or_container, variable_type)
when variable_type in [:compute, :params, :output] do
type = get_in(policy, [Access.key!(variable_type)])
deep_new(tensor_or_container, fn x -> Nx.as_type(x, type) end)
type = Map.fetch!(policy, variable_type)

deep_new(tensor_or_container, fn tensor ->
if not Nx.Type.integer?(Nx.type(tensor)) and not Nx.Type.integer?(type) do
Nx.as_type(tensor, type)
else
tensor
end
end)
end
end

0 comments on commit acdc002

Please sign in to comment.