diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc
index 790d80047e..63ce1492f9 100644
--- a/cpp/serve/function_table.cc
+++ b/cpp/serve/function_table.cc
@@ -152,6 +152,10 @@ void FunctionTable::Init(String reload_lib_path, Device device, picojson::object
   }
   ICHECK_EQ(this->model_metadata_.tensor_parallel_shards, num_shards);
   ICHECK_EQ(this->model_metadata_.pipeline_parallel_stages, num_stages);
+  // Invoke the CUDA graph allocation init function if it is defined.
+  if (cuda_graph_alloc_init_func_.defined()) {
+    this->cuda_graph_alloc_init_func_();
+  }
 }
 
 ObjectRef FunctionTable::LoadParams(const std::string& model_path, Device device) {
@@ -231,6 +235,7 @@ void FunctionTable::_InitFunctions() {
   this->apply_penalty_func_ = mod->GetFunction("apply_penalty_inplace", true);
   this->apply_bitmask_func_ = mod->GetFunction("apply_bitmask_inplace", true);
   this->alloc_embedding_tensor_func_ = mod_get_func("alloc_embedding_tensor");
+  this->cuda_graph_alloc_init_func_ = mod_get_func("cuda_graph_alloc_init");
   this->create_kv_cache_func_ = mod_get_func("create_flashinfer_paged_kv_cache");
   if (this->model_metadata_.sliding_window_size != -1 || !this->create_kv_cache_func_.defined()) {
     PackedFunc f_create_rnn_state = mod_get_func("create_rnn_state");
diff --git a/cpp/serve/function_table.h b/cpp/serve/function_table.h
index f2a513ec87..46fae540da 100644
--- a/cpp/serve/function_table.h
+++ b/cpp/serve/function_table.h
@@ -99,6 +99,7 @@ struct FunctionTable {
   PackedFunc apply_penalty_func_;
   PackedFunc apply_bitmask_func_;
   PackedFunc alloc_embedding_tensor_func_;
+  PackedFunc cuda_graph_alloc_init_func_;
   PackedFunc create_kv_cache_func_;
   PackedFunc reset_kv_cache_func_;
   bool support_backtracking_kv_;
diff --git a/python/mlc_llm/compiler_pass/attach_cuda_graph_alloc_init_func.py b/python/mlc_llm/compiler_pass/attach_cuda_graph_alloc_init_func.py
new file mode 100644
index 0000000000..70f6598852
--- /dev/null
+++ b/python/mlc_llm/compiler_pass/attach_cuda_graph_alloc_init_func.py
@@ -0,0 +1,33 @@
+"""The pass that attaches an initialization function for CUDA graph allocations."""
+
+import tvm
+from tvm import IRModule, relax
+
+
+@tvm.transform.module_pass(opt_level=0, name="AttachCUDAGraphAllocInitFunc")
+class AttachCUDAGraphAllocInitFunc:  # pylint: disable=too-few-public-methods
+    """Attach a function that pre-populates the CUDA graph allocation cache."""
+
+    def __init__(self):
+        pass
+
+    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
+        """Entrypoint"""
+        bb = relax.BlockBuilder(mod)
+        alloc_func_gv = None
+        for gv, _ in mod.functions_items():
+            if gv.name_hint.startswith("cuda_graph_alloc"):
+                assert alloc_func_gv is None
+                alloc_func_gv = gv
+        if alloc_func_gv is None:
+            return mod
+
+        with bb.function("cuda_graph_alloc_init", []):
+            bb.emit_func_output(
+                relax.op.call_builtin_with_ctx(
+                    "vm.builtin.cuda_graph.get_cached_alloc",
+                    args=[alloc_func_gv, relax.PrimValue(0)],
+                    sinfo_args=relax.ObjectStructInfo(),
+                )
+            )
+        return bb.finalize()
diff --git a/python/mlc_llm/compiler_pass/pipeline.py b/python/mlc_llm/compiler_pass/pipeline.py
index 363be1a59b..af1cf9f0e9 100644
--- a/python/mlc_llm/compiler_pass/pipeline.py
+++ b/python/mlc_llm/compiler_pass/pipeline.py
@@ -12,6 +12,7 @@
 from mlc_llm.interface.compiler_flags import IPCAllReduceStrategyType
 from mlc_llm.support import logging
 
+from .attach_cuda_graph_alloc_init_func import AttachCUDAGraphAllocInitFunc
 from .attach_embedding_allocator import AttachAllocEmbeddingTensorFunc
 from .attach_logit_processor import AttachLogitProcessFunc
 from .attach_sampler import AttachGPUSamplingFunc
@@ -159,7 +160,6 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule:
                 ),
                 ScatterTupleGetItem(),
                 PipelineParallelRewrite(),
-                _DebugDump("after-pipeline-rewrite.py", debug_dump, show_meta=False),
                 tvm.relax.transform.RewriteDataflowReshape(),
                 tvm.relax.transform.ToNonDataflow(),
                 tvm.relax.transform.RemovePurityChecking(),
@@ -172,6 +172,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule:
                 tvm.relax.transform.StaticPlanBlockMemory(),
                 AttachMetadataWithMemoryUsage(metadata),
                 tvm.relax.transform.RewriteCUDAGraph(),
+                AttachCUDAGraphAllocInitFunc(),
                 tvm.relax.transform.LowerGPUIPCAllocStorage(),
                 tvm.relax.transform.LowerAllocTensor(),
                 tvm.relax.transform.KillAfterLastUse(),
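
For reviewers, here is roughly what `AttachCUDAGraphAllocInitFunc` attaches. This is an illustrative TVMScript-style sketch, not the exact lowered output: it assumes `tvm.relax.transform.RewriteCUDAGraph()` has already emitted a `cuda_graph_alloc` function into the module, and the printed form of the builtin call may differ.

```python
# Illustrative sketch only: the Relax entry point attached by the pass,
# written in TVMScript-like pseudocode. `cuda_graph_alloc` is the
# allocation function produced earlier by RewriteCUDAGraph.
@R.function(pure=False)
def cuda_graph_alloc_init() -> R.Object:
    # Running get_cached_alloc once at engine startup populates the VM's
    # cache of CUDA-graph storage, so the first prefill/decode call does
    # not pay for the capture-time allocations.
    return R.call_builtin_with_ctx(
        "vm.builtin.cuda_graph.get_cached_alloc",
        (cuda_graph_alloc, R.prim_value(0)),
        sinfo_args=R.Object,
    )
```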
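On the runtime side, `FunctionTable::Init` now looks up `cuda_graph_alloc_init` and calls it once when the engine loads. Below is a hedged Python sketch of the same hookup, not the engine's actual code path; it assumes a Relax-VM model library compiled with CUDA graphs enabled, and `lib_path` is a placeholder.

```python
# Hedged sketch mirroring the C++ FunctionTable logic. `lib_path` is a
# placeholder for a compiled model library.
import tvm
from tvm import relax

dev = tvm.cuda(0)
lib = tvm.runtime.load_module(lib_path)
vm = relax.VirtualMachine(lib, dev)

# Equivalent of `if (cuda_graph_alloc_init_func_.defined())` in C++:
# only invoke the function when the compiled module attached one.
try:
    init_func = vm.module.get_function("cuda_graph_alloc_init", query_imports=True)
except AttributeError:
    init_func = None
if init_func is not None:
    init_func()  # warm the CUDA-graph allocation cache before serving
```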