Merge pull request #6 from jax-ml/multidevice
Lazily load CUDA modules before launch
sharadmv authored Sep 13, 2022
2 parents f4e589d + 47498b3 commit 78738f4
Showing 7 changed files with 255 additions and 121 deletions.
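
In outline, this change separates compilation from module loading: the Triton kernel is still compiled to PTX/cubin when the call is lowered, but the resulting CUDA module is only loaded onto a device the first time the kernel is actually launched there, and the loaded function handle is cached per device under a lock. The snippet below is a minimal Python sketch of that per-device caching pattern, for orientation only; the names are illustrative, and the real implementation is TritonExecutable::load in lib/triton_kernel_call.cc further down.

# Illustrative sketch of lazy, per-device kernel loading (not the library's code).
import threading

class LazyKernelCache:
    """Caches a loaded kernel handle per device, loading it on first use."""

    def __init__(self, load_fn):
        self._load_fn = load_fn   # e.g. wraps cuModuleLoadData + cuModuleGetFunction
        self._kernels = {}        # device id -> loaded kernel handle
        self._lock = threading.Lock()

    def get(self, device):
        with self._lock:
            if device not in self._kernels:
                self._kernels[device] = self._load_fn(device)
            return self._kernels[device]
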
1 change: 0 additions & 1 deletion examples/matrix_multiplication.py
@@ -20,7 +20,6 @@

import jax
import jax.numpy as jnp
import math

m=512
n=512
43 changes: 25 additions & 18 deletions jax_triton/triton_call.py
@@ -26,20 +26,21 @@
from jax.interpreters import mlir
from jax import tree_util
from jax._src import util
from jax._src.lib import xla_bridge as xb
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import mhlo
import numpy as np
import torch
import triton
import triton.language as tl

from jax_triton import custom_call
from jax_triton import triton_kernel_call

os.environ["TRITON_CACHE_DIR"] = ""
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip

xc.register_custom_call_target("triton_call", custom_call.get_custom_call(), platform="CUDA")
xc.register_custom_call_target("triton_kernel_call", triton_kernel_call.get_custom_call(), platform="CUDA")

def get_triton_type(obj: Any) -> str:
type_map = {
@@ -79,8 +80,6 @@ def get_triton_python_ir(aval):

def compile(triton_function, constants, *, key, device=0, num_warps=4, num_stages=2):
def lower(*args):
arg_types = [get_triton_python_ir(a) for a in args]
attributes = {i: 16 for i in range(len(args))}
triton_function._warmup(arg_types=arg_types, device=device,
attributes=attributes, constants=constants, num_warps=num_warps,
num_stages=num_stages, key=key, is_manual_warmup=True,
@@ -133,18 +132,24 @@ def aval_to_layout(aval):
arange = np.arange(aval.ndim, dtype='int64')[::-1].copy()
return ir.DenseIntElementsAttr.get(arange, type=ir.IndexType.get())

def emit_triton_call(triton_func, avals_in, avals_out, grid, num_warps, num_stages,
def emit_triton_call(ctx, triton_func, grid, num_warps, num_stages,
dump_binary_path: Optional[str], **metaparams):
metadata = {triton_func.arg_names.index(k) : v for k, v in metaparams.items()}
compile(triton_func, metadata, num_warps=num_warps, num_stages=num_stages, key="foo")(*avals_in, *avals_out)
loaded_binary = triton_func.bin_cache["foo"]
kernel_ptr = loaded_binary.kernel
shared_mem = loaded_binary.shared_mem
all_args = [*ctx.avals_in, *ctx.avals_out]
arg_types = [get_triton_python_ir(a) for a in all_args]
attributes = {i: 16 for i in range(len(all_args))}
# TODO(sharadmv): Handle multiple devices. Right now we assume device 0, which
# is fine when all the GPUs are identical, but it won't work in general.
binary = triton_func._compile(arg_types=arg_types, device=0,
attributes=attributes, constants=metadata, num_warps=num_warps,
num_stages=num_stages, extern_libs={})
name, asm, shared_mem = binary.name, binary.asm, binary.shared_mem
if dump_binary_path is not None:
binary = dict(
asm=loaded_binary.asm,
asm=asm,
shared_mem=shared_mem,
name=loaded_binary.bin.name)
name=name)
with open(dump_binary_path, "wb") as fp:
pickle.dump(binary, fp)

@@ -158,22 +163,24 @@ def emit_triton_call(triton_func, avals_in, avals_out, grid, num_warps, num_stag
grid_1, grid_2 = grid_[1], grid_[2]
else:
assert False
arity = len(avals_in) + len(avals_out)
descriptor = custom_call.make_triton_call_descriptor(kernel_ptr, shared_mem, grid_0, grid_1, grid_2, num_warps, arity)
return descriptor
arity = len(ctx.avals_in) + len(ctx.avals_out)
descriptor, keepalive = triton_kernel_call.make_triton_call_descriptor(
name, asm, shared_mem, grid_0, grid_1, grid_2, num_warps, arity)
return descriptor, keepalive

def triton_call_lowering(ctx, *args, kernel, out_shapes, grid, num_warps=4, num_stages=2,
dump_binary_path: Optional[str], **metaparams):
out_type = ir.TupleType.get_tuple([
ir.RankedTensorType.get(out_shape.shape, mlir.dtype_to_ir_type(out_shape.dtype))
for out_shape in out_shapes])
i32_type = ir.IntegerType.get_signless(32)
descriptor = emit_triton_call(kernel, ctx.avals_in, ctx.avals_out, grid,
num_warps, num_stages, dump_binary_path,
**metaparams)
descriptor, keepalive = emit_triton_call(ctx, kernel, grid,
num_warps, num_stages, dump_binary_path,
**metaparams)
ctx.module_context.add_keepalive(keepalive)
out = mhlo.CustomCallOp(
[out_type], args,
call_target_name=ir.StringAttr.get("triton_call"),
call_target_name=ir.StringAttr.get("triton_kernel_call"),
has_side_effect=ir.BoolAttr.get(False),
backend_config=ir.StringAttr.get(descriptor),
api_version=ir.IntegerAttr.get(i32_type, 1),
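
For orientation, here is a minimal usage sketch of the Python entry point this lowering serves. It is illustrative only: the kernel and wrapper are not part of this diff, and while the keyword names (kernel, out_shapes, grid, and the metaparameter BLOCK_SIZE) mirror the triton_call_lowering signature above, the exact public jax_triton.triton_call API should be treated as an assumption.

# Hypothetical usage sketch; not part of this commit.
import jax
import jax.numpy as jnp
import triton
import triton.language as tl

import jax_triton as jt

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one BLOCK_SIZE-wide slice of the inputs.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + offsets)
    y = tl.load(y_ptr + offsets)
    tl.store(out_ptr + offsets, x + y)

def add(x, y, block_size=8):
    out_shapes = [jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)]
    return jt.triton_call(
        x, y, kernel=add_kernel, out_shapes=out_shapes,
        grid=(x.size // block_size,), BLOCK_SIZE=block_size)

x = jnp.arange(64, dtype=jnp.float32)
y = 2 * x
print(jax.jit(add)(x, y))
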
100 changes: 0 additions & 100 deletions lib/custom_call.cc

This file was deleted.

135 changes: 135 additions & 0 deletions lib/triton_kernel_call.cc
@@ -0,0 +1,135 @@
/* Copyright 2022 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "triton_kernel_call.h"

#include <cassert>
#include <cstdlib>   // std::strtoull
#include <cstring>   // std::memcpy
#include <iostream>
#include <mutex>     // std::lock_guard
#include <string>

#include <pybind11/pybind11.h>
#include "cuda.h"

namespace py = pybind11;

namespace jax_triton {

// 48 KiB: the default per-block static shared memory limit; kernels that need
// more dynamic shared memory must opt in (see TritonExecutable::load below).
const int TRITON_MAX_N_SHARED_BYTES = 49152;
const int TRITON_MAX_SHARED_OPTIN = 49152;


void TritonExecutable::launch(CUstream stream, void** buffers) {
CUdevice dev;
CUcontext ctx;
// Set the current context to the stream context so we can query the stream
// device
cuStreamGetCtx(stream, &ctx);
cuCtxSetCurrent(ctx);
// Only load the kernel if it hasn't already been loaded for this device.
cuCtxGetDevice(&dev);
CUfunction kernel = load(dev);
std::string params;
params.resize(8 * arity);
char* params_ptr = &params[0];
for (uint32_t i = 0; i < arity; i++) {
params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
std::memcpy(params_ptr, &buffers[i], 8);
params_ptr += 8;
}
size_t params_size = static_cast<size_t>(params_ptr - &params[0]);
void* config[] = {
CU_LAUNCH_PARAM_BUFFER_POINTER,
static_cast<void*>(const_cast<char*>(params.data())),
CU_LAUNCH_PARAM_BUFFER_SIZE, &params_size,
CU_LAUNCH_PARAM_END
};
CUresult result = cuLaunchKernel(kernel, grid_0, grid_1, grid_2, num_warps * 32, 1, 1, shared_mem, stream, nullptr, config);
if (result != 0) {
std::cout << "Failed launch: " << result << std::endl;
}
};

CUfunction TritonExecutable::load(CUdevice device) {
const std::lock_guard<std::mutex> lock(mut);
if (is_loaded(device)) {
return kernels[device];
}
// Mimics Triton kernel loading
std::string assembly;
auto iter = asm_map.find("cubin");
if (iter != asm_map.end())
assembly = py::cast<std::string>(asm_map["cubin"]);
else {
assert(asm_map.contains("ptx"));
assembly = py::cast<std::string>(asm_map["ptx"]);
}
CUfunction fun;
CUmodule mod;
cuModuleLoadData(&mod, assembly.c_str());
cuModuleGetFunction(&fun, mod, name.c_str());
int n_regs = 0;
int n_spills = 0;
cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun);
cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
n_spills /= 4;
int shared_optin;
cuDeviceGetAttribute(&shared_optin,
CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
device);
if (shared_mem > TRITON_MAX_N_SHARED_BYTES &&
shared_optin > TRITON_MAX_SHARED_OPTIN) {
cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED);
int shared_total, shared_static;
cuDeviceGetAttribute(
&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR,
device);
cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES,
fun);
cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
shared_optin - shared_static);
}
kernels[device] = fun;
return fun;
};

void do_custom_call(CUstream stream, void** buffers,
char* opaque, size_t opaque_len) {
uint64_t descriptor = std::strtoull(opaque, NULL, 0);
TritonExecutable* executable = TritonExecutable::from_descriptor(descriptor);
executable->launch(stream, buffers);
}

std::pair<std::string, py::object> MakeTritonExecutable(std::string name, asm_map_t asm_map, uint32_t shared_mem, uint32_t grid_0, uint32_t grid_1, uint32_t grid_2, uint32_t num_warps, uint32_t arity) {
auto triton_call = std::make_unique<TritonExecutable>(
name, asm_map, shared_mem, grid_0, grid_1, grid_2, num_warps, arity);
std::string descriptor = std::to_string(reinterpret_cast<uint64_t>(triton_call.get()));
py::capsule callback_capsule(triton_call.release(), [](void* ptr) {
delete reinterpret_cast<TritonExecutable*>(ptr);
});
return std::make_pair(descriptor, py::object(std::move(callback_capsule)));
}

template <typename T>
pybind11::capsule EncapsulateFunction(T* fn) {
return pybind11::capsule(reinterpret_cast<void*>(fn), "xla._CUSTOM_CALL_TARGET");
}

PYBIND11_MODULE(triton_kernel_call, m) {
m.def("make_triton_call_descriptor", &MakeTritonExecutable);
m.def("get_custom_call", [](){
return EncapsulateFunction(do_custom_call);
});
}

} // namespace jax_triton
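
To make the parameter-packing loop in TritonExecutable::launch above easier to follow, here is a small, self-contained Python sketch (illustrative only, not part of the library) of the same layout: every kernel argument is an 8-byte device pointer, aligned to an 8-byte boundary and appended to one flat buffer, which the C++ code then hands to cuLaunchKernel through CU_LAUNCH_PARAM_BUFFER_POINTER and CU_LAUNCH_PARAM_BUFFER_SIZE.

import struct

def pack_kernel_params(device_ptrs):
    """Packs 64-bit device pointers into one 8-byte-aligned parameter buffer."""
    buf = bytearray()
    for ptr in device_ptrs:
        # Round the current offset up to the next multiple of 8, padding with zeros.
        aligned = (len(buf) + 7) & ~7
        buf.extend(b"\x00" * (aligned - len(buf)))
        buf.extend(struct.pack("<Q", ptr))  # little-endian 64-bit pointer value
    return bytes(buf)

# Three fake device pointers pack into 24 bytes (three 8-byte slots).
print(len(pack_kernel_params([0x7F0000000000, 0x7F0000001000, 0x7F0000002000])))
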