feat: experimental sharding backend #1544

Open
wants to merge 31 commits into base: main
31 commits
25788c0
feat: add basic backend layout
polvalente Sep 25, 2024
6585353
proof of concept solution
polvalente Sep 25, 2024
54e7abf
feat: add working POC for sharding
polvalente Sep 26, 2024
5fe92ca
refactor: use only sharding compiler
polvalente Sep 26, 2024
09b844a
refactor: move things into __compile__
polvalente Sep 26, 2024
f6bd4ed
feat: working EXLA example
polvalente Sep 26, 2024
4c8200f
wip: initial work on dot (doesn't work)
polvalente Sep 26, 2024
2f9370a
wip
polvalente Sep 26, 2024
918118e
wip: refactor input sharding calculation
polvalente Sep 28, 2024
bd06388
wip: refactor to support broadcasts
polvalente Sep 28, 2024
5d7b106
wip: rework sharding representation (each slice is a shard)
polvalente Sep 29, 2024
1f7dfb3
feat: deal with broadcasting and re-slicing
polvalente Sep 29, 2024
243eee2
refactor: build parents tree into each shard
polvalente Oct 10, 2024
4bb2d97
feat: support implicit broadcasting
polvalente Oct 10, 2024
87f1c35
chore: remove unused var'
polvalente Oct 10, 2024
bca3943
feat: support dot product without contraction sharding
polvalente Oct 10, 2024
73a3553
feat: support constants
polvalente Oct 10, 2024
52fc860
add :tensor op to example function
polvalente Oct 10, 2024
7706678
feat: support squeeze
polvalente Oct 10, 2024
0ddb69a
chore: remove empty file
polvalente Oct 10, 2024
489d655
refactor: remove TensorSharding module
polvalente Oct 11, 2024
42b5ca6
chore: add stubs for missing callbacks
polvalente Oct 11, 2024
59034c1
test: add tests
polvalente Oct 11, 2024
d6da7a8
fix: transpose axis
polvalente Oct 11, 2024
68565bc
chore: remove example .exs files
polvalente Oct 11, 2024
fb86713
Update nx/lib/nx/defn/sharding_compiler.ex
polvalente Oct 14, 2024
0c111df
chore: format
polvalente Oct 14, 2024
5c5c881
chore: remove __stream__
polvalente Oct 14, 2024
5e05bbb
Merge remote-tracking branch 'origin/main' into pv-feat/experimental-…
polvalente Oct 15, 2024
6eb6fba
feat: add graph splitter for all-gather/all-reduce operations (#1545)
polvalente Oct 17, 2024
dce2c60
feat: add shard execution workflow (#1557)
polvalente Nov 28, 2024
4 changes: 4 additions & 0 deletions nx/config/config.exs
@@ -5,3 +5,7 @@ import Config
# true inside Nx.
config :nx, :verify_grad, true
config :nx, :verify_binary_size, true

# If set to true, shards and sharding stages will be
# inspected with their debug ids alongside their unique ref ids
config :nx, :debug_shards, true
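
The consumer of :debug_shards is not part of this hunk. As a hedged illustration only, a flag like this is typically read at runtime with Application.get_env/3 to decide how much detail to include when inspecting shards; the helper below and its field names are hypothetical, not the PR's code.

defmodule ShardDebugLabel do
  # Hypothetical helper: include the debug id in the inspection label only
  # when the :debug_shards flag is enabled. Arguments are illustrative.
  def label(ref, debug_id) do
    if Application.get_env(:nx, :debug_shards, false) do
      "#{inspect(debug_id)} (#{inspect(ref)})"
    else
      inspect(ref)
    end
  end
end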
1 change: 1 addition & 0 deletions nx/lib/nx/application.ex
@@ -4,6 +4,7 @@ defmodule Nx.Application do

def start(_type, _args) do
children = [
Nx.Defn.ShardingCompiler.ShardRegistry,
%{id: Nx.Serving.PG, start: {:pg, :start_link, [Nx.Serving.PG]}},
{Nx.HiddenServing, Nx.Serving.PG}
]
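
The ShardRegistry module itself is defined elsewhere in this PR. As a sketch only, a process placed in the supervision tree like this could delegate its child spec to Elixir's built-in Registry; the module below is an assumption for illustration, not the PR's implementation.

defmodule Nx.Defn.ShardingCompiler.ShardRegistry.Sketch do
  # Hypothetical child spec: a unique-key Registry for tracking shard processes.
  def child_spec(_opts) do
    Registry.child_spec(keys: :unique, name: Nx.Defn.ShardingCompiler.ShardRegistry)
  end
end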
179 changes: 179 additions & 0 deletions nx/lib/nx/defn/sharding_compiler.ex
@@ -0,0 +1,179 @@
defmodule Nx.Defn.ShardingCompiler do
alias Nx.Tensor, as: T
alias Nx.Defn.Expr

alias Nx.Defn.ShardingCompiler.Shard

alias Nx.Defn.ShardingCompiler.Passes.ShardPropagation
# alias Nx.Defn.ShardingCompiler.Passes.GraphSplitter

@behaviour Nx.Defn.Compiler

@impl true
def __jit__(key, vars, fun, args, opts) do
opts =
Keyword.validate!(opts, [
:sharding_config,
sharding_compiler: Nx.Defn.Evaluator,
sharding_compiler_options: []
])

[args] = args

{%T{
type: type,
data: %ShardPropagation{
shards: output_shards
}
}, parameter_ids_to_index,
shape} =
propagate_shards(vars, fun, opts[:sharding_config] || [])

data_sections =
output_shards |> Enum.sort_by(fn {axis, _} -> axis end) |> cartesian_product()

# Find the parents for each data section
# Group by inputs
# For each input, sort the shards by axis
# For each axis, find the minimum start and the maximum end (we need to test for slicing inside the code as well).
# An axis might not be present in the mapping; in that case we need the full axis.

result =
for section <- data_sections do
shards_by_input_id =
section
|> Enum.flat_map(fn {_axis, shard} ->
get_root_parents(shard)
end)
|> Enum.group_by(fn shard -> shard.input_id end)

inputs_by_index =
parameter_ids_to_index
|> Enum.sort_by(fn {_id, idx} -> idx end)
|> Enum.map(fn {id, idx} -> {id, Enum.fetch!(args, idx)} end)

sliced_inputs =
for {input_id, input_fn} <- inputs_by_index do
input = input_fn.()
shards = shards_by_input_id[input_id]
shards_by_axis = Enum.group_by(shards, & &1.axis)

{_, _, starts_reverse, lengths_reverse} =
Enum.reduce(Tuple.to_list(input.shape), {shards_by_axis, 0, [], []}, fn
axis_size, {shards_by_axis, axis, starts, lengths} ->
{shards, shards_by_axis} = Map.pop(shards_by_axis, axis)

{starts, lengths} =
if shards do
min_start = Enum.min(Enum.map(shards, & &1.start))
max_end = Enum.max(Enum.map(shards, &(&1.start + &1.length - 1)))

starts = [min_start | starts]
lengths = [max_end - min_start + 1 | lengths]
{starts, lengths}
else
starts = [0 | starts]
lengths = [axis_size | lengths]
{starts, lengths}
end

{shards_by_axis, axis + 1, starts, lengths}
end)

starts = Enum.reverse(starts_reverse)
lengths = Enum.reverse(lengths_reverse)

Nx.slice(input, starts, lengths)
end

{out_starts, []} =
Enum.map_reduce(0..(tuple_size(shape) - 1)//1, section, fn
axis, [{axis, shard} | shards] ->
{shard.start, shards}

_axis, shards ->
{0, shards}
end)

caster_fn = fn result, acc ->
Nx.put_slice(acc, out_starts, result)
end

sharding_compiler = opts[:sharding_compiler]
sharding_compiler_options = opts[:sharding_compiler_options]

vars =
Enum.with_index(sliced_inputs, fn arg, idx ->
arg
|> Expr.parameter(:root, idx)
end)

compiled_fun =
sharding_compiler.__compile__({key, section}, vars, fun, sharding_compiler_options)

shard_fn = fn [args] ->
[res] =
compiled_fun.([
Enum.map(Tuple.to_list(args), fn arg ->
fn -> arg end
end)
])

res
end

{[List.to_tuple(sliced_inputs)], shard_fn, caster_fn}
end

output_holder = Nx.iota(shape, type: type)
[{output_holder, result}]
end

defp cartesian_product([{axis, first} | rest]) do
for x <- first, y <- cartesian_product(rest), do: [{axis, x} | y]
end

defp cartesian_product([]), do: [[]]

@impl true
def __compile__(_key, _vars, _fun, _opts) do
raise "Not implemented yet"
end

def propagate_shards(vars, fun, sharding_config) do
expr = fun.(vars)

tensor_shardings =
sharding_config
|> Enum.zip_with(vars, fn config, var ->
Shard.from_config(var, config)
end)
|> Enum.with_index(fn x, idx -> {idx, x} end)
|> Map.new()

{container, _cache, state} = ShardPropagation.traverse(expr, tensor_shardings)

{container, state.parameter_ids_to_index, expr.shape}
end

@impl true
def __partitions_options__(_keyword) do
raise "__partitions_options__ not supported"
end

@impl true
def __to_backend__(_keyword) do
raise "__to_backend__ not supported"
end

def init(opts), do: opts

defp get_root_parents(shard, acc \\ [])

defp get_root_parents(%Shard{parents: []} = shard, acc), do: List.flatten([shard | acc])

defp get_root_parents(%Shard{parents: parents}, acc) do
Enum.reduce(parents, acc, &get_root_parents/2)
|> List.flatten()
end
end
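
Based on the options validated in __jit__/5 (:sharding_config, :sharding_compiler, :sharding_compiler_options), a minimal invocation sketch could look as follows. The shape of the sharding_config entries is an assumption, since they are consumed by Shard.from_config/2, which is not shown in this diff, and at this experimental stage the call yields the output holder together with per-shard functions rather than an executed tensor.

fun = fn a, b -> Nx.add(a, b) end

jitted =
  Nx.Defn.jit(fun,
    compiler: Nx.Defn.ShardingCompiler,
    # one (assumed) sharding config entry per input
    sharding_config: [%{0 => 2}, %{0 => 2}],
    sharding_compiler: Nx.Defn.Evaluator,
    sharding_compiler_options: []
  )

jitted.(Nx.iota({4, 4}), Nx.iota({4, 4}))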