Add sharding

sunnycase committed Nov 21, 2024
Commit ac83c88 (parent 4e022cd)
Showing 13 changed files with 323 additions and 146 deletions.
3 changes: 2 additions & 1 deletion modules/Nncase.Modules.CPU/CodeGen/CPU/CSourceBuiltn.cs
@@ -39,7 +39,8 @@ public static class CSourceBuiltn
 public const string KernelHeader = @"#pragma once
 #include <nncase/ntt/ntt.h>
 using namespace nncase::ntt;
-using namespace nncase::ntt::dist_policy;
+using namespace nncase::ntt::distributed;
+using namespace nncase::ntt::distributed::dist_policy;
 ";

@@ -245,7 +245,7 @@ protected override CSymbol VisitBuffer(TIR.Buffer expr)
         var type = VisitEntry.Parameters.AsValueEnumerable().Contains(expr) || expr.MemSpan.Location == MemoryLocation.Rdata || expr.MemSpan.Start is TensorConst
             ? (expr.DistributedType == null
                 ? $"tensor_view<{expr.ElemType.ToC()}, {KernelUtility.DimensionsToC(expr.Dimensions)}, {KernelUtility.StridesToC(expr.Strides)}> "
-                : $"dist_tensor_view<{expr.ElemType.ToC()}, {KernelUtility.DimensionsToC(expr.DistributedType.TensorType.Shape)}, {KernelUtility.NdSBPToC(expr.DistributedType.NdSBP)}, topology::thread, {KernelUtility.StridesToC(expr.Strides)}> ")
+                : $"sharded_tensor_view<{expr.ElemType.ToC()}, {KernelUtility.DimensionsToC(expr.DistributedType.TensorType.Shape)}, {KernelUtility.DistributedToC(expr.DistributedType)}, {KernelUtility.StridesToC(expr.Strides)}> ")
             : $"tensor<{expr.ElemType.ToC()}, {KernelUtility.DimensionsToC(expr.Dimensions)}> ";

         symbol = new(type, expr.Name);
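For orientation, a sketch of the kind of declaration the visitor now emits for a distributed buffer; the element type, shape, sharding, and stride spellings below are illustrative assumptions (4 threads, axis 0 split), not captured compiler output:

    // Illustrative only: generated declaration for a buffer whose
    // DistributedType is set; the fixed_shape/fixed_strides spellings are
    // assumed to come from DimensionsToC/StridesToC.
    sharded_tensor_view<float, fixed_shape<128, 256>,
                        sharding<mesh<topology::thread, 4>, S<0>, B>,
                        fixed_strides<256, 1>>
        buffer0;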
39 changes: 24 additions & 15 deletions modules/Nncase.Modules.CPU/CodeGen/CPU/KernelUtility.cs
@@ -64,28 +64,37 @@ public static string StridesToC(ReadOnlySpan<Expr> dimensions)
         return sb.ToString();
     }

-    public static string NdSBPToC(IRArray<SBP> ndSBP)
+    public static string DistributedToC(DistributedType distributedType)
     {
-        var sb = new StringBuilder("dist<");
-        for (int i = 0; i < ndSBP.Count; i++)
+        var placement = distributedType.Placement;
+        var ndSBP = distributedType.NdSBP;
+
+        var sb = new StringBuilder("sharding<mesh<topology::thread, ");
+        for (int i = 0; i < placement.Rank; i++)
         {
-            var value = ndSBP[i];
-            if (value is SBPBroadCast)
-            {
-                sb.Append('B');
-            }
-            else if (value is SBPPartialSum)
-            {
-                sb.Append('P');
-            }
-            else if (value is SBPSplit split)
-            {
-                sb.Append($"S<{split.Axis}>");
-            }
-
-            if (i != ndSBP.Count - 1)
+            var value = placement.Hierarchy[i];
+            sb.Append($"{value}");
+            if (i != placement.Rank - 1)
            {
                 sb.Append(", ");
             }
         }
+
+        var nonAxisPolicy = ndSBP.Any(x => x is SBPPartialSum) ? "P" : "B";
+        sb.Append('>');
+
+        for (int axis = 0; axis < distributedType.TensorType.Shape.Rank; axis++)
+        {
+            var value = from sbp in ndSBP.Select((x, i) => (x, i))
+                        where sbp.x is SBPSplit split && split.Axis == axis
+                        select sbp.i;
+            if (value.Any())
+            {
+                sb.Append($", S<{string.Join(", ", value)}>");
+            }
+            else
+            {
+                sb.Append($", {nonAxisPolicy}");
+            }
+        }
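A worked example of the new encoding, with hypothetical values: a rank-2 tensor on a single-level thread placement of size 4. The final closing bracket is appended outside the visible hunk.

    // Hypothetical outputs of DistributedToC:
    //   NdSBP = [S(0)] -> "sharding<mesh<topology::thread, 4>, S<0>, B>"
    //   NdSBP = [P]    -> "sharding<mesh<topology::thread, 4>, P, P>"
    // One mesh extent per placement dimension, then one policy per tensor
    // axis: S<i...> for split axes, otherwise B, or P when any SBP is a
    // partial sum.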
@@ -107,7 +107,7 @@ class tensor_reduce_sync_impl {
 public:
     void reduce_group_sync() const noexcept {
 @foreach(var comb in combinations) {
-        var reduce_group_index = string.Join(", ", Enumerable.Range(0, hierarchy.Length).Select(i => comb.Contains(i) ? "0" : "ntt::" + hierarchyNames[i] + "id()"));
+        var reduce_group_index = string.Join(", ", Enumerable.Range(0, hierarchy.Length).Select(i => comb.Contains(i) ? "0" : "ntt::distributed::" + hierarchyNames[i] + "id()"));
 @:if constexpr (Kind == tar::reduce_kind::@(GetName(comb, string.Empty))) {
 @:    tar::@(GetName(comb))(@(reduce_group_index)).arrive_and_wait();
 @:}
@@ -151,7 +151,7 @@
     }

 @{
-    var cur_index = string.Join(", ", Enumerable.Range(0, hierarchy.Length).Select(i => "ntt::" + hierarchyNames[i] + "id()"));
+    var cur_index = string.Join(", ", Enumerable.Range(0, hierarchy.Length).Select(i => "ntt::distributed::" + hierarchyNames[i] + "id()"));
 }

 template <class TIn, class TOut> void operator()(TIn &src, TOut &&dest) {
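To see what the template change produces, here is a sketch of one generated branch, assuming a two-level {block, thread} hierarchy with hierarchyNames = {"b", "t"} and a hypothetical GetName result of thread_group (all assumptions, not captured output):

    // Hypothetical expansion for comb = {thread}: the reduced axis
    // collapses to 0; the surviving axis now uses the qualified id.
    if constexpr (Kind == tar::reduce_kind::thread) {
        tar::thread_group(ntt::distributed::bid(), 0).arrive_and_wait();
    }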
@@ -10,6 +10,6 @@
 #include <array>
 #include <cstddef>

-namespace nncase::ntt {
+namespace nncase::ntt::distributed {
 constexpr std::array<size_t, @(hierarchy.Length)> topology_dims = {@(string.Join(", ", hierarchy))};
 }
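For a hypothetical three-level hierarchy of {1, 4, 8}, the template now renders:

    // Generated topology_def.h (illustrative hierarchy values):
    namespace nncase::ntt::distributed {
    constexpr std::array<size_t, 3> topology_dims = {1, 4, 8};
    }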
23 changes: 21 additions & 2 deletions ntt/include/nncase/ntt/arch/cpu/distributed.h
@@ -14,8 +14,27 @@
 */
 #pragma once
 #include "../../distributed.h"
+#include "runtime.h"

+namespace nncase::ntt::distributed {
+template <> struct program_id_getter<topology::thread> {
+    static size_t id() noexcept {
+        return runtime::cpu_thread_context_t::current().tid;
+    }
+};
+
+template <> struct program_id_getter<topology::block> {
+    static size_t id() noexcept {
+        return runtime::cpu_thread_context_t::current().bid;
+    }
+};
+
+template <> struct program_id_getter<topology::chip> {
+    static size_t id() noexcept {
+        return runtime::cpu_thread_context_t::current().cid;
+    }
+};
+
-namespace nncase::ntt {
 inline size_t tid() noexcept { return program_id<topology::thread>(); }
 inline size_t bid() noexcept { return program_id<topology::block>(); }
 inline size_t cid() noexcept { return program_id<topology::chip>(); }
@@ -25,4 +44,4 @@ inline constexpr size_t tdim() noexcept {
 }
 inline constexpr size_t bdim() noexcept { return program_dim(topology::block); }
 inline constexpr size_t cdim() noexcept { return program_dim(topology::chip); }
-} // namespace nncase::ntt
+} // namespace nncase::ntt::distributed
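A minimal usage sketch of the relocated accessors; the linearization formula is illustrative, not from the commit:

    #include <nncase/ntt/arch/cpu/distributed.h>

    // Flatten the three topology levels into one program index; each id
    // is resolved through the program_id_getter specializations above.
    size_t flat_program_id() noexcept {
        using namespace nncase::ntt::distributed;
        return (cid() * bdim() + bid()) * tdim() + tid();
    }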
20 changes: 0 additions & 20 deletions ntt/include/nncase/ntt/arch/cpu/runtime.h
@@ -50,26 +50,6 @@ extern size_t bdim;
 extern size_t cdim;
 } // namespace nncase::ntt::runtime

-namespace nncase::ntt {
-template <> struct program_id_getter<topology::thread> {
-    static size_t id() noexcept {
-        return runtime::cpu_thread_context_t::current().tid;
-    }
-};
-
-template <> struct program_id_getter<topology::block> {
-    static size_t id() noexcept {
-        return runtime::cpu_thread_context_t::current().bid;
-    }
-};
-
-template <> struct program_id_getter<topology::chip> {
-    static size_t id() noexcept {
-        return runtime::cpu_thread_context_t::current().cid;
-    }
-};
-} // namespace nncase::ntt
-
 extern "C" NTT_RUNTIME_API void
 block_entry(const nncase::ntt::runtime::cpu_block_entry_params_t &params);
 using block_entry_t = decltype(block_entry) *;
120 changes: 16 additions & 104 deletions ntt/include/nncase/ntt/distributed.h
@@ -14,131 +14,43 @@
 */
 #pragma once
 #include "arch/cpu/topology.h"
-#include "nncase/ntt/shape.h"
-#include "primitive_ops.h"
-#include "tensor.h"
 #include <cstddef>
+#include <cstdint>
-#include <utility>

 #ifdef NNCASE_CPU_MODULE
 #include <topology_def.h>
 #endif

-namespace nncase::ntt {
+namespace nncase::ntt::distributed {
 inline constexpr size_t topology_levels =
     static_cast<size_t>(topology::count__);

 #ifndef NNCASE_CPU_MODULE
 constexpr size_t program_dim(topology /* topo */) noexcept { return 1; }
 #else
 constexpr size_t program_dim(topology topo) noexcept {
-    auto index =
-        static_cast<size_t>(topo) - (topology_levels - topology_dims.size());
-    return topology_dims[index];
+    int32_t index =
+        static_cast<int32_t>(topo) - (topology_levels - topology_dims.size());
+    return index < 0 ? 1 : topology_dims[index];
 }
 #endif

+template <topology Scope = static_cast<topology>(topology_levels - 1)>
+constexpr size_t topology_size() noexcept {
+    return [] {
+        size_t size = 1;
+        for (size_t i = 0; i <= static_cast<size_t>(Scope); i++) {
+            size *= program_dim(static_cast<topology>(i));
+        }
+        return size;
+    }();
+}
+
 template <topology Topology> struct program_id_getter {
     static size_t id() noexcept;
 };

 template <topology Topology> size_t program_id() noexcept {
     return program_id_getter<Topology>::id();
 }

-namespace dist_policy {
-// Broadcast
-struct B {
-    static constexpr size_t local_dim(size_t global_dim,
-                                      size_t /* topology_dim */,
-                                      size_t /* axis */) noexcept {
-        return global_dim;
-    }
-};
-
-// Split
-template <size_t Axis> struct S {
-    static constexpr size_t axis = Axis;
-
-    static constexpr size_t local_dim(size_t global_dim, size_t topology_dim,
-                                      size_t axis) noexcept {
-        return axis == Axis ? ntt::ceil_div(global_dim, topology_dim)
-                            : global_dim;
-    }
-};
-
-// Partial
-struct P {
-    static constexpr size_t local_dim(size_t global_dim,
-                                      size_t /* topology_dim */,
-                                      size_t /* axis */) noexcept {
-        return global_dim;
-    }
-};
-} // namespace dist_policy
-
-template <class... TPolicies> struct dist {
-    static constexpr std::tuple<TPolicies...> policies = {TPolicies{}...};
-    static constexpr size_t size = sizeof...(TPolicies);
-};
-
-namespace detail {
-template <class TDist, topology Scope, class GlobalShape>
-constexpr size_t get_local_dim(GlobalShape shape, size_t axis) noexcept {
-    auto local_dim = shape.at(axis);
-    auto cnt_topology =
-        static_cast<topology>(static_cast<size_t>(Scope) + 1 - TDist::size);
-    auto apply_policy = [&](auto policy) {
-        local_dim =
-            policy.local_dim(local_dim, program_dim(cnt_topology), axis);
-        cnt_topology =
-            static_cast<topology>(static_cast<size_t>(cnt_topology) + 1);
-    };
-    std::apply([&](auto... policies) { (apply_policy(policies), ...); },
-               TDist::policies);
-    return local_dim;
-}
-
-template <class TDist, topology Scope, class GlobalShape, size_t... Axes>
-constexpr auto get_fixed_local_dim(GlobalShape,
-                                   std::index_sequence<Axes...>) noexcept {
-    return fixed_shape<get_local_dim<TDist, Scope>(GlobalShape{}, Axes)...>{};
-}
-
-template <class GlobalShape, class TDist, topology Scope>
-struct local_shape_type {
-    using type = ranked_shape<GlobalShape::rank()>;
-};
-
-template <size_t... Dims, class TDist, topology Scope>
-struct local_shape_type<fixed_shape<Dims...>, TDist, Scope> {
-    using type = decltype(get_fixed_local_dim<TDist, Scope>(
-        fixed_shape<Dims...>{}, std::make_index_sequence<sizeof...(Dims)>{}));
-};
-} // namespace detail
-
-template <
-    class T, class GlobalShape, class TDist, topology Scope,
-    class LocalStrides = default_strides_t<
-        typename detail::local_shape_type<GlobalShape, TDist, Scope>::type>>
-class dist_tensor_view
-    : public tensor_view<
-          T, typename detail::local_shape_type<GlobalShape, TDist, Scope>::type,
-          LocalStrides> {
-  public:
-    using local_tensor_type = tensor_view<
-        T, typename detail::local_shape_type<GlobalShape, TDist, Scope>::type,
-        LocalStrides>;
-
-    using local_tensor_type::local_tensor_type;
-
-    local_tensor_type &local() noexcept {
-        return static_cast<local_tensor_type &>(*this);
-    }
-
-    const local_tensor_type &local() const noexcept {
-        return static_cast<local_tensor_type &>(*this);
-    }
-};
-} // namespace nncase::ntt
+} // namespace nncase::ntt::distributed
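The int32_t change in program_dim is an underflow guard. A worked example, assuming the topology enum orders chip < block < thread (which the index arithmetic implies) and a partial hierarchy topology_dims = {4, 8} covering only block and thread:

    // With topology_levels == 3 and topology_dims = {4, 8} (assumed):
    //   program_dim(topology::chip)   -> index = 0 - (3 - 2) = -1 -> 1
    //   program_dim(topology::block)  -> index = 1 - 1 = 0        -> 4
    //   program_dim(topology::thread) -> index = 2 - 1 = 1        -> 8
    // The old size_t arithmetic wrapped the chip case to SIZE_MAX and
    // indexed past the end of topology_dims; returning 1 instead also
    // keeps topology_size<>() correct for partial hierarchies.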
3 changes: 3 additions & 0 deletions ntt/include/nncase/ntt/ntt.h
@@ -38,6 +38,9 @@
#include "kernels/unpack.h"
#include "kernels/where.h"
#include "primitive_ops.h"
#include "remote_tensor.h"
#include "sharded_tensor.h"
#include "sharding.h"
#include "tensor.h"
#include "tensor_ops.h"
#include "ukernels.h"
26 changes: 26 additions & 0 deletions ntt/include/nncase/ntt/remote_tensor.h
@@ -0,0 +1,26 @@
+
+/* Copyright 2019-2021 Canaan Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+#include "sharding.h"
+#include "tensor.h"
+
+namespace nncase::ntt::distributed {
+template <class T, class Shape, class Strides> class remote_tensor_view {
+  public:
+    static remote_tensor_view create(ranked_shape<topology_levels> program_ids,
+                                     T *local_address) noexcept;
+};
+} // namespace nncase::ntt::distributed
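Since remote_tensor_view::create is only declared here, any usage is speculative; a sketch of the intended call shape, with every concrete value assumed:

    // Hypothetical sketch: view a peer thread's local buffer through its
    // program ids. The create() signature is the declared one; the shape,
    // strides, and id order (chip, block, thread) are assumptions.
    using namespace nncase::ntt;
    using namespace nncase::ntt::distributed;

    void example(float *peer_local_address) {
        ranked_shape<topology_levels> peer_ids{0, 0, 1}; // chip 0, block 0, thread 1
        auto view = remote_tensor_view<float, fixed_shape<128>, fixed_strides<1>>::
            create(peer_ids, peer_local_address);
        (void)view; // reads/writes would target the peer's memory
    }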