diff --git a/modules/Nncase.Modules.CPU/CodeGen/CPU/CSourceBuiltn.cs b/modules/Nncase.Modules.CPU/CodeGen/CPU/CSourceBuiltn.cs index 85441998bf..9f8bd17ff9 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CPU/CSourceBuiltn.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CPU/CSourceBuiltn.cs @@ -40,7 +40,7 @@ public static class CSourceBuiltn #include using namespace nncase::ntt; using namespace nncase::ntt::distributed; -using namespace nncase::ntt::distributed::dist_policy; +using namespace nncase::ntt::distributed::shard_policy; "; diff --git a/modules/Nncase.Modules.CPU/CodeGen/CPU/KernelCSourceConvertVisitor.cs b/modules/Nncase.Modules.CPU/CodeGen/CPU/KernelCSourceConvertVisitor.cs index b67ca8c90d..2ba6e5135a 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CPU/KernelCSourceConvertVisitor.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CPU/KernelCSourceConvertVisitor.cs @@ -526,7 +526,11 @@ protected override CSymbol VisitCall(Call expr) #if DEBUG_PRINT IndentScope.Writer.IndWrite($"runtime_util->printf(\"call {deviceFunc.Name} bid %d tid %d\\n\", bid, tid);\n"); #endif - var arguments = expr.Arguments.AsValueEnumerable().Select(Visit).ToArray(); + var arguments = expr.Arguments.AsValueEnumerable().Select(x => x switch + { + TIR.Buffer b => VisitBuffer(b, local: true), + _ => Visit(x), + }).ToArray(); _refFuncs.Add(deviceFunc); IndentScope.Writer.IndWrite($"{deviceFunc.Name}({string.Join(",", arguments.Select(arg => arg.Name))});\n"); } diff --git a/modules/Nncase.Modules.CPU/CodeGen/CPU/KernelUtility.cs b/modules/Nncase.Modules.CPU/CodeGen/CPU/KernelUtility.cs index 4084888032..84c249313c 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CPU/KernelUtility.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CPU/KernelUtility.cs @@ -80,8 +80,8 @@ public static string DistributedToC(DistributedType distributedType) } } - var nonAxisPolicy = ndSBP.Any(x => x is SBPPartialSum) ? "P" : "B"; - sb.Append('>'); + var implicitPolicy = ndSBP.Any(x => x is SBPPartialSum) ? "P" : "B"; + sb.Append($">, {implicitPolicy}"); for (int axis = 0; axis < distributedType.TensorType.Shape.Rank; axis++) { @@ -94,7 +94,7 @@ public static string DistributedToC(DistributedType distributedType) } else { - sb.Append($", {nonAxisPolicy}"); + sb.Append($", I"); } } diff --git a/ntt/include/nncase/ntt/sharding.h b/ntt/include/nncase/ntt/sharding.h index d98cec698d..57757a466d 100644 --- a/ntt/include/nncase/ntt/sharding.h +++ b/ntt/include/nncase/ntt/sharding.h @@ -18,9 +18,17 @@ #include "shape.h" namespace nncase::ntt::distributed { -namespace dist_policy { +namespace shard_policy { // Broadcast -struct B { +struct B {}; + +// Partial +template struct P { + static constexpr ntt::reduce_op reduce_op = ReduceOp; +}; + +// Implicit +struct I { template static constexpr size_t local_dim(size_t global_dim) noexcept { return global_dim; @@ -37,15 +45,7 @@ template struct S { return ntt::ceil_div(global_dim, divider); } }; - -// Partial -struct P { - template - static constexpr size_t local_dim(size_t global_dim) noexcept { - return global_dim; - } -}; -} // namespace dist_policy +} // namespace shard_policy template struct mesh { using shape_type = fixed_shape; @@ -57,11 +57,13 @@ template struct mesh { remote_program_id(ranked_shape index) noexcept; }; -template struct sharding { +template +struct sharding { using mesh_type = Mesh; + using implicit_policy_type = ImplicitPolicy; + using axis_policy_type = std::tuple; - static constexpr std::tuple policies = {Policies{}...}; - static constexpr size_t policies_size = sizeof...(Policies); + static constexpr size_t axis_policies_size = sizeof...(AxisPolicies); }; namespace detail { @@ -103,11 +105,11 @@ constexpr size_t get_submesh_start() noexcept { template constexpr size_t get_local_shard_dim(GlobalShape shape) noexcept { - static_assert(GlobalShape::rank() == Sharding::policies_size, + static_assert(GlobalShape::rank() == Sharding::axis_policies_size, "Invalid sharding."); auto local_dim = shape.at(Axis); - return std::get(Sharding::policies) + return std::get(typename Sharding::axis_policy_type{}) .template local_dim(local_dim); }