[C++] Invoke storage allocation for CUDA Graph explicitly (#3042)
This PR adds a function that invokes the storage allocation function
generated by the CUDA Graph rewrite. With this function, we now trigger
the storage allocation manually at initialization time.

We need this because the storage allocation may contain CUDA IPC memory
allocations that have to run through a Disco session. If a function that
needs the CUDA Graph storage allocation runs first outside a Disco
session, an error may occur unless the allocation has been initialized
in advance.
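
As an illustrative sketch only (the "model_lib.so" path and the loading code are hypothetical; the engine does this inside FunctionTable::Init, shown in the diff below), the new hook amounts to looking up the attached cuda_graph_alloc_init function and calling it once right after the module is loaded:

import tvm
from tvm import relax

# Hypothetical example: load a compiled MLC model library and eagerly run the
# CUDA Graph storage allocation, mirroring the FunctionTable::Init change.
dev = tvm.cuda(0)
lib = tvm.runtime.load_module("model_lib.so")  # hypothetical library path
vm = relax.VirtualMachine(lib, dev)

try:
    cuda_graph_alloc_init = vm.module.get_function("cuda_graph_alloc_init")
except AttributeError:
    cuda_graph_alloc_init = None  # the function is only attached for CUDA Graph builds

if cuda_graph_alloc_init is not None:
    cuda_graph_alloc_init()  # trigger the storage allocation at initialization time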
MasterJH5574 authored Nov 21, 2024
1 parent e349684 commit d23d6f5
Showing 4 changed files with 41 additions and 1 deletion.
5 changes: 5 additions & 0 deletions cpp/serve/function_table.cc
@@ -152,6 +152,10 @@ void FunctionTable::Init(String reload_lib_path, Device device, picojson::object
  }
  ICHECK_EQ(this->model_metadata_.tensor_parallel_shards, num_shards);
  ICHECK_EQ(this->model_metadata_.pipeline_parallel_stages, num_stages);
  // Invoke the CUDA graph allocation init function if it is defined.
  if (cuda_graph_alloc_init_func_.defined()) {
    this->cuda_graph_alloc_init_func_();
  }
}

ObjectRef FunctionTable::LoadParams(const std::string& model_path, Device device) {
@@ -231,6 +235,7 @@ void FunctionTable::_InitFunctions() {
  this->apply_penalty_func_ = mod->GetFunction("apply_penalty_inplace", true);
  this->apply_bitmask_func_ = mod->GetFunction("apply_bitmask_inplace", true);
  this->alloc_embedding_tensor_func_ = mod_get_func("alloc_embedding_tensor");
  this->cuda_graph_alloc_init_func_ = mod_get_func("cuda_graph_alloc_init");
  this->create_kv_cache_func_ = mod_get_func("create_flashinfer_paged_kv_cache");
  if (this->model_metadata_.sliding_window_size != -1 || !this->create_kv_cache_func_.defined()) {
    PackedFunc f_create_rnn_state = mod_get_func("create_rnn_state");
1 change: 1 addition & 0 deletions cpp/serve/function_table.h
@@ -99,6 +99,7 @@ struct FunctionTable {
  PackedFunc apply_penalty_func_;
  PackedFunc apply_bitmask_func_;
  PackedFunc alloc_embedding_tensor_func_;
  PackedFunc cuda_graph_alloc_init_func_;
  PackedFunc create_kv_cache_func_;
  PackedFunc reset_kv_cache_func_;
  bool support_backtracking_kv_;
33 changes: 33 additions & 0 deletions python/mlc_llm/compiler_pass/attach_cuda_graph_alloc_init_func.py
@@ -0,0 +1,33 @@
"""The pass that attaches an empty function for initialization."""

import tvm
from tvm import IRModule, relax


@tvm.transform.module_pass(opt_level=0, name="AttachCUDAGraphAllocInitFunc")
class AttachCUDAGraphAllocInitFunc: # pylint: disable=too-few-public-methods
"""Attach an empty function for initialization."""

def __init__(self):
pass

def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
"""Entrypoint"""
bb = relax.BlockBuilder(mod)
alloc_func_gv = None
for gv, _ in mod.functions_items():
if gv.name_hint.startswith("cuda_graph_alloc"):
assert alloc_func_gv is None
alloc_func_gv = gv
if alloc_func_gv is None:
return mod

with bb.function("cuda_graph_alloc_init", []):
bb.emit_func_output(
relax.op.call_builtin_with_ctx(
"vm.builtin.cuda_graph.get_cached_alloc",
args=[alloc_func_gv, relax.PrimValue(0)],
sinfo_args=relax.ObjectStructInfo(),
)
)
return bb.finalize()
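
At runtime the attached function simply calls the vm.builtin.cuda_graph.get_cached_alloc VM builtin on the cuda_graph_alloc function produced by RewriteCUDAGraph, with cache index 0, so the first call performs the allocation and later calls reuse the cached storage. A rough sketch of applying the pass on its own (assuming the input module has already gone through tvm.relax.transform.RewriteCUDAGraph, which is what produces the cuda_graph_alloc function):

import tvm
from mlc_llm.compiler_pass.attach_cuda_graph_alloc_init_func import (
    AttachCUDAGraphAllocInitFunc,
)

def attach_alloc_init(mod: tvm.IRModule) -> tvm.IRModule:
    # Instances created via @tvm.transform.module_pass are callable on an
    # IRModule; the pass is a no-op when no "cuda_graph_alloc*" function exists.
    return AttachCUDAGraphAllocInitFunc()(mod)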
3 changes: 2 additions & 1 deletion python/mlc_llm/compiler_pass/pipeline.py
@@ -12,6 +12,7 @@
from mlc_llm.interface.compiler_flags import IPCAllReduceStrategyType
from mlc_llm.support import logging

from .attach_cuda_graph_alloc_init_func import AttachCUDAGraphAllocInitFunc
from .attach_embedding_allocator import AttachAllocEmbeddingTensorFunc
from .attach_logit_processor import AttachLogitProcessFunc
from .attach_sampler import AttachGPUSamplingFunc
@@ -159,7 +160,6 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule:
),
ScatterTupleGetItem(),
PipelineParallelRewrite(),
_DebugDump("after-pipeline-rewrite.py", debug_dump, show_meta=False),
tvm.relax.transform.RewriteDataflowReshape(),
tvm.relax.transform.ToNonDataflow(),
tvm.relax.transform.RemovePurityChecking(),
@@ -172,6 +172,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule:
tvm.relax.transform.StaticPlanBlockMemory(),
AttachMetadataWithMemoryUsage(metadata),
tvm.relax.transform.RewriteCUDAGraph(),
AttachCUDAGraphAllocInitFunc(),
tvm.relax.transform.LowerGPUIPCAllocStorage(),
tvm.relax.transform.LowerAllocTensor(),
tvm.relax.transform.KillAfterLastUse(),
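
Placement note: AttachCUDAGraphAllocInitFunc is inserted immediately after tvm.relax.transform.RewriteCUDAGraph because it searches the module for the cuda_graph_alloc function that the rewrite produces; running it any earlier would make it a no-op.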