Skip to content

Commit

Permalink
Add implicit shard_policy
Browse files: browse the repository at this point in the history
  • Loading branch information
sunnycase committed Nov 21, 2024
1 parent ac83c88 commit f847eba
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 21 deletions.
2 changes: 1 addition & 1 deletion modules/Nncase.Modules.CPU/CodeGen/CPU/CSourceBuiltn.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public static class CSourceBuiltn
#include <nncase/ntt/ntt.h>
using namespace nncase::ntt;
using namespace nncase::ntt::distributed;
using namespace nncase::ntt::distributed::dist_policy;
using namespace nncase::ntt::distributed::shard_policy;
";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,11 @@ protected override CSymbol VisitCall(Call expr)
#if DEBUG_PRINT
IndentScope.Writer.IndWrite($"runtime_util->printf(\"call {deviceFunc.Name} bid %d tid %d\\n\", bid, tid);\n");
#endif
var arguments = expr.Arguments.AsValueEnumerable().Select(Visit).ToArray();
var arguments = expr.Arguments.AsValueEnumerable().Select(x => x switch
{
TIR.Buffer b => VisitBuffer(b, local: true),
_ => Visit(x),
}).ToArray();
_refFuncs.Add(deviceFunc);
IndentScope.Writer.IndWrite($"{deviceFunc.Name}({string.Join(",", arguments.Select(arg => arg.Name))});\n");
}
Expand Down
6 changes: 3 additions & 3 deletions modules/Nncase.Modules.CPU/CodeGen/CPU/KernelUtility.cs
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ public static string DistributedToC(DistributedType distributedType)
}
}

var nonAxisPolicy = ndSBP.Any(x => x is SBPPartialSum) ? "P" : "B";
sb.Append('>');
var implicitPolicy = ndSBP.Any(x => x is SBPPartialSum) ? "P<reduce_op::sum>" : "B";
sb.Append($">, {implicitPolicy}");

for (int axis = 0; axis < distributedType.TensorType.Shape.Rank; axis++)
{
Expand All @@ -94,7 +94,7 @@ public static string DistributedToC(DistributedType distributedType)
}
else
{
sb.Append($", {nonAxisPolicy}");
sb.Append($", I");
}
}

Expand Down
34 changes: 18 additions & 16 deletions ntt/include/nncase/ntt/sharding.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,17 @@
#include "shape.h"

namespace nncase::ntt::distributed {
namespace dist_policy {
namespace shard_policy {
// Broadcast
struct B {
struct B {};

// Partial
template <reduce_op ReduceOp> struct P {
static constexpr ntt::reduce_op reduce_op = ReduceOp;
};

// Implicit
struct I {
template <class Mesh>
static constexpr size_t local_dim(size_t global_dim) noexcept {
return global_dim;
Expand All @@ -37,15 +45,7 @@ template <size_t... Axes> struct S {
return ntt::ceil_div(global_dim, divider);
}
};

// Partial
struct P {
template <class Mesh>
static constexpr size_t local_dim(size_t global_dim) noexcept {
return global_dim;
}
};
} // namespace dist_policy
} // namespace shard_policy

template <topology Scope, size_t... Dims> struct mesh {
using shape_type = fixed_shape<Dims...>;
Expand All @@ -57,11 +57,13 @@ template <topology Scope, size_t... Dims> struct mesh {
remote_program_id(ranked_shape<shape_type::rank()> index) noexcept;
};

template <class Mesh, class... Policies> struct sharding {
template <class Mesh, class ImplicitPolicy, class... AxisPolicies>
struct sharding {
using mesh_type = Mesh;
using implicit_policy_type = ImplicitPolicy;
using axis_policy_type = std::tuple<AxisPolicies...>;

static constexpr std::tuple<Policies...> policies = {Policies{}...};
static constexpr size_t policies_size = sizeof...(Policies);
static constexpr size_t axis_policies_size = sizeof...(AxisPolicies);
};

namespace detail {
Expand Down Expand Up @@ -103,11 +105,11 @@ constexpr size_t get_submesh_start() noexcept {

template <class Sharding, size_t Axis, class GlobalShape>
constexpr size_t get_local_shard_dim(GlobalShape shape) noexcept {
static_assert(GlobalShape::rank() == Sharding::policies_size,
static_assert(GlobalShape::rank() == Sharding::axis_policies_size,
"Invalid sharding.");

auto local_dim = shape.at(Axis);
return std::get<Axis>(Sharding::policies)
return std::get<Axis>(typename Sharding::axis_policy_type{})
.template local_dim<typename Sharding::mesh_type>(local_dim);
}

Expand Down

0 comments on commit f847eba

Please sign in to comment.