Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor the C code template in third_party/nvidia/backend/driver.py #4722

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

sfzhu93
Copy link
Contributor

@sfzhu93 sfzhu93 commented Sep 13, 2024

Refactor the C code template in driver.py

Previously, in the third_party/nvidia/backend/driver.py, there was a long format string to define a C source code. In this commit, I moved the long format string in Python code into a separate C file, and use macros on the driver.py side to fill in the missing part in the C code.

This improves readability and makes it easy for future extension to the driver.py.

  • I am not making a trivial change, such as fixing a typo in a comment.

  • I have written a PR description following these
    rules.

  • I have run pre-commit run --from-ref origin/main --to-ref HEAD.

  • This PR does not need a test because this PR doesn't introduce any new features or bug fixing.

  • I have not added any lit tests.

@minjang minjang marked this pull request as ready for review September 13, 2024 03:25
Copy link
Contributor

@minjang minjang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Comment on lines +163 to +197
def gen_c_def_macro(macro_name, macro_value):
return f"#define {macro_name} {macro_value}\n"

# macros to define:
"""
#define EXTRA_INNER_LAUNCH_PARAM_DECLS
#define INNER_LAUNCH_CUDA_CHECK_ARGS
#define LAUNCH_PY_ARGS
#define PY_ARG_FORMAT_STR
#define EXTRA_LAUNCH_PARSE_PY_ARGS
#define DEVICE_PTR_INFO_VARS
#define TMA_DESC_VARS
#define EXTRA_INNER_LAUNCH_CALL_ARGS
"""
macro_defs = gen_c_def_macro("EXTRA_INNER_LAUNCH_PARAM_DECLS", ", " + arg_decls if arg_decls else "")
macro_defs += gen_c_def_macro("INNER_LAUNCH_CUDA_CHECK_ARGS", ', '.join(f"&arg{i}" for i in params))
macro_defs += gen_c_def_macro("LAUNCH_PY_ARGS",
';'.join([f"{_extracted_type(ty)} _arg{i}" for i, ty in signature.items()]))
macro_defs += gen_c_def_macro("PY_ARG_FORMAT_STR", f'"{format}"')
macro_defs += gen_c_def_macro("EXTRA_LAUNCH_PARSE_PY_ARGS", ", " + args_list if args_list else "")
device_ptr_info_var_list = []
tma_desc_var_list = []
for i, ty in signature.items():
if ty[0] == "*":
device_ptr_info_var_list.append(
f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;")
elif ty == "nvTmaDesc":
tma_desc_var_list.append(f"CUtensorMap* tma_ptr{i} = getTmaDesc(_arg{i}); if (!tma_ptr{i}) return NULL;")

macro_defs += gen_c_def_macro("DEVICE_PTR_INFO_VARS", " \\\n".join(device_ptr_info_var_list))
macro_defs += gen_c_def_macro("TMA_DESC_VARS", " \\\n".join(tma_desc_var_list))
extra_inner_launch_call_args = ', '.join(internal_args_list)
macro_defs += gen_c_def_macro("EXTRA_INNER_LAUNCH_CALL_ARGS",
', ' + extra_inner_launch_call_args if extra_inner_launch_call_args else "")
src = macro_defs + Path(os.path.join(dirname, "cuda_launcher.c")).read_text()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

personally I find this harder to read/modify than the previous version.
Can we only move the functions that don't need patching in a separate file and keep the rest in a template kind of way (and maybe break up functions if it helps moving more code out of python)

@gflegar
Copy link
Collaborator

gflegar commented Sep 19, 2024

Fly-by comment here, since this seems quite related to what we did about the launcher internally at Google:

I don't think we need python string interpolation at all, and we could completely precompile a generic version of the launcher. @ThomasRaoux would this be something interesting for you? If so, I could work on that (or if @sfzhu93 wants to take that on, I'm super happy to have that support!). I wanted to try upstreaming this, but didn't get to it yet. The best I have right now is this patch I can share (don't have a machine with command line git at the moment to make a proper PR out of it 😅): https://gist.github.com/gflegar/6fea5e50023a69e17c64e73eb0037cab. It should hopefully git apply -p2 launcher.patch on top of commit 174e780. The caveats are that:

  1. I don't have the CMake files for this figured out yet.
  2. I didn't get around to porting over the latest TMA stuff from [nvidia] Support passing TMA descriptors by-value #4498

(Sorry for not making this more polished, lots of stuff going on right now, I can definitely polish it more in a few weeks time, but just wanted to bring this option up since stuff seems to be moving on this.)

@ThomasRaoux
Copy link
Collaborator

Fly-by comment here, since this seems quite related to what we did about the launcher internally at Google:

I don't think we need python string interpolation at all, and we could completely precompile a generic version of the launcher. @ThomasRaoux would this be something interesting for you? If so, I could work on that (or if @sfzhu93 wants to take that on, I'm super happy to have that support!). I wanted to try upstreaming this, but didn't get to it yet. The best I have right now is this patch I can share (don't have a machine with command line git at the moment to make a proper PR out of it 😅): https://gist.github.com/gflegar/6fea5e50023a69e17c64e73eb0037cab. It should hopefully git apply -p2 launcher.patch on top of commit 174e780. The caveats are that:

  1. I don't have the CMake files for this figured out yet.
  2. I didn't get around to porting over the latest TMA stuff from [nvidia] Support passing TMA descriptors by-value #4498

(Sorry for not making this more polished, lots of stuff going on right now, I can definitely polish it more in a few weeks time, but just wanted to bring this option up since stuff seems to be moving on this.)

Interesting, did you measure the dispatch cost of doing that?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants