Skip to content

Commit

Permalink
multiple resource groups
Browse the repository at this point in the history
  • Loading branch information
chhwang committed Sep 27, 2024
1 parent dfae17b commit 97ab2df
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 66 deletions.
54 changes: 37 additions & 17 deletions ark/api/planner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,17 +138,25 @@ std::string Planner::Impl::plan(bool pretty) const {

auto get_context = [&](const ModelNodeRef &node,
const std::string &key) -> Json {
if (node->context.find(key) != node->context.end()) {
try {
return node->context.at(key);
} catch (const Json::out_of_range &e) {
}
return Json();
};

// Like get_context, but yields only the most recently pushed value
// for `key` (the back of the stored stack), or an empty Json if the
// key has no entries.
auto get_latest_context = [&](const ModelNodeRef &node,
                              const std::string &key) -> Json {
    const auto stack = get_context(node, key);
    return stack.empty() ? Json() : stack.back();
};

for (const auto &node : model_.nodes()) {
const auto &op = node->op;
if (op->is_virtual()) continue;

auto ctx_config = get_context(node, "Config");
auto ctx_config = get_latest_context(node, "Config");

Json config;
if (!ctx_config.empty()) {
Expand Down Expand Up @@ -223,8 +231,8 @@ std::string Planner::Impl::plan(bool pretty) const {
}

size_t granularity = config.value("Granularity", 1);
auto ctx_id = get_context(node, "Id");
auto ctx_sync = get_context(node, "Sync");
auto ctx_id = get_latest_context(node, "Id");
auto ctx_sync = get_latest_context(node, "Sync");
int id = ctx_id.empty() ? -1 : ctx_id.get<int>();
bool sync = ctx_sync.empty() ? true : ctx_sync.get<bool>();
if (id == prev_ctx_id && !sync) {
Expand All @@ -245,24 +253,31 @@ std::string Planner::Impl::plan(bool pretty) const {
task_info["Ops"][0]["Config"] = config;
task_infos.push_back(task_info);

auto ctx_processor_range = get_context(node, "ProcessorRange");
auto ctx_warp_range = get_context(node, "WarpRange");
auto ctx_sram_range = get_context(node, "SramRange");
auto ctx_processor_range_list = get_context(node, "ProcessorRange");
auto ctx_warp_range = get_latest_context(node, "WarpRange");
auto ctx_sram_range = get_latest_context(node, "SramRange");

Json processor_group;
if (!ctx_processor_range.empty()) {
Json resource_group;
bool new_processor_group = true;
if (ctx_processor_range_list.empty()) {
size_t num_processors = std::min(num_sm, num_tasks);
processor_group["ProcessorRange"] = {0, num_processors};
resource_group["ProcessorRange"] = {0, num_processors};
max_processor_id = std::max(max_processor_id, num_processors);
} else if (ctx_processor_range_list.size() == 1 ||
(id != prev_ctx_id)) {
auto &ctx_processor_range = ctx_processor_range_list[0];
processor_group["ProcessorRange"] = ctx_processor_range;
resource_group["ProcessorRange"] = ctx_processor_range;
max_processor_id = std::max(
max_processor_id, ctx_processor_range[1].get<size_t>());
} else {
size_t num_processors = std::min(num_sm, num_tasks);
processor_group["ProcessorRange"] = {0, num_processors};
max_processor_id = std::max(max_processor_id, num_processors);
new_processor_group = false;
resource_group["ProcessorRange"] =
ctx_processor_range_list.back();
}

Json resource_group;
resource_group["ProcessorRange"] =
processor_group["ProcessorRange"];
if (!ctx_warp_range.empty()) {
resource_group["WarpRange"] = ctx_warp_range;
max_warp_id =
Expand All @@ -280,9 +295,14 @@ std::string Planner::Impl::plan(bool pretty) const {
{"TaskRange", {0, num_tasks}},
{"Granularity", granularity}}};

processor_group["ResourceGroups"] = Json::array();
processor_group["ResourceGroups"].push_back(resource_group);
processor_groups.push_back(processor_group);
if (new_processor_group) {
processor_group["ResourceGroups"] = Json::array();
processor_group["ResourceGroups"].push_back(resource_group);
processor_groups.push_back(processor_group);
} else {
processor_groups.back()["ResourceGroups"].push_back(
resource_group);
}
}
prev_ctx_id = id;
first_op = false;
Expand Down
14 changes: 2 additions & 12 deletions ark/model/model_graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,8 @@ Json ModelGraphContextStack::get(const std::string &key) const {
return Json();
}

/// Collect the topmost (latest) value of every non-empty context stack,
/// keyed by context name.
std::map<std::string, Json> ModelGraphContextStack::get_all() const {
    std::map<std::string, Json> latest;
    for (const auto &[key, stack] : this->storage_) {
        if (!stack.empty()) {
            latest[key] = *stack.back();
        }
    }
    return latest;
}

Json ModelGraphContextStack::dump() const {
Json j;
Json j = Json::object();
for (const auto &pair : this->storage_) {
j[pair.first] = Json::array();
for (const auto &value : pair.second) {
Expand Down Expand Up @@ -227,7 +217,7 @@ ModelNodeRef ModelGraph::Impl::add_op(ModelOpRef op) {
producer->consumers.push_back(node);
}

node->context = context_stack_->get_all();
node->context = context_stack_->dump();

nodes_.push_back(node);
return node;
Expand Down
2 changes: 0 additions & 2 deletions ark/model/model_graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ class ModelGraphContextStack {

Json get(const std::string &key) const;

std::map<std::string, Json> get_all() const;

Json dump() const;
};

Expand Down
2 changes: 1 addition & 1 deletion ark/model/model_node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class ModelNode {
UniqueList<ModelNodeRef> producers;

/// Graph context of this node.
std::map<std::string, Json> context;
Json context;
};

} // namespace ark
Expand Down
72 changes: 38 additions & 34 deletions examples/llama/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,27 +495,29 @@ def forward(
scores = ark.tensor([bsz, self.n_local_heads, seqlen, seqlen], dtype=self.dtype)
scores_shards = ark.sharding(scores, axis=1, dim_per_shard=1)
results = []
with Context(
warp_range=[0, 8],
sram_range=[0, 49344],
sync=False,
config={
"NumWarps": 4,
"Granularity": 2,
"SramBytes": 24672,
"Tile": [256, 128],
},
):
with Context(processor_range=[0, 304]):
for i in range(len(scores_shards)):
xq_shard_reshaped = ark.reshape(xq_shards[i], [bsz, 1, seqlen, self.head_dim])
keys_shard_reshaped = ark.reshape(keys_shards[i], [bsz, 1, seqlen, self.head_dim])
scores_shard_reshaped = ark.reshape(scores_shards[i], [bsz, 1, seqlen, seqlen])
res = ark.matmul(xq_shard_reshaped, keys_shard_reshaped, scores_shard_reshaped, transpose_other=True)
res = ark.mul(res, 1.0 / math.sqrt(self.head_dim), res)
if mask is not None:
res = ark.add(res, mask, res)
with Context(
processor_range=[i*8, (i+1)*8],
warp_range=[0, 8],
sram_range=[0, 49344],
sync=False,
config={
"NumWarps": 4,
"Granularity": 2,
"SramBytes": 24672,
"Tile": [256, 128],
},
):
xq_shard_reshaped = ark.reshape(xq_shards[i], [bsz, 1, seqlen, self.head_dim])
keys_shard_reshaped = ark.reshape(keys_shards[i], [bsz, 1, seqlen, self.head_dim])
scores_shard_reshaped = ark.reshape(scores_shards[i], [bsz, 1, seqlen, seqlen])
res = ark.matmul(xq_shard_reshaped, keys_shard_reshaped, scores_shard_reshaped, transpose_other=True)
res = ark.mul(res, 1.0 / math.sqrt(self.head_dim), res)
if mask is not None:
res = ark.add(res, mask, res)
results.append(res)
scores = ark.identity(scores, deps=results)
scores = ark.identity(scores, deps=results)

def softmax(scores):
with Context(
Expand Down Expand Up @@ -546,22 +548,24 @@ def softmax(scores):
output_shards = ark.sharding(output, axis=2, dim_per_shard=1)

results = []
with Context(
warp_range=[0, 4],
sram_range=[0, 24672],
sync=False,
config={
"NumWarps": 4,
"SramBytes": 24672,
"Tile": [256, 128],
},
):
with Context(processor_range=[0, 304]):
for i in range(len(output_shards)):
values_shard_reshaped = ark.reshape(values_shards[i], [bsz, 1, seqlen, self.head_dim])
scores_shard_reshaped = ark.reshape(scores_shards[i], [bsz, 1, seqlen, seqlen])
output_shard_reshaped = ark.reshape(output_shards[i], [bsz, 1, seqlen, self.head_dim])
res = ark.matmul(scores_shard_reshaped, values_shard_reshaped, output_shard_reshaped)
results.append(res)
with Context(
processor_range=[i*8, (i+1)*8],
warp_range=[0, 4],
sram_range=[0, 24672],
sync=False,
config={
"NumWarps": 4,
"SramBytes": 24672,
"Tile": [256, 128],
},
):
values_shard_reshaped = ark.reshape(values_shards[i], [bsz, 1, seqlen, self.head_dim])
scores_shard_reshaped = ark.reshape(scores_shards[i], [bsz, 1, seqlen, seqlen])
output_shard_reshaped = ark.reshape(output_shards[i], [bsz, 1, seqlen, self.head_dim])
res = ark.matmul(scores_shard_reshaped, values_shard_reshaped, output_shard_reshaped)
results.append(res)
output = ark.identity(output, deps=results)
output = ark.reshape(
output, [bsz, seqlen, self.head_dim * self.n_local_heads]
Expand Down

0 comments on commit 97ab2df

Please sign in to comment.