Skip to content

Commit

Permalink
Brodey | cublas + cudnn codegen (#58)
Browse files Browse the repository at this point in the history
  • Loading branch information
brodeynewman authored Dec 3, 2024
1 parent d338c85 commit 2bbc5d2
Show file tree
Hide file tree
Showing 6 changed files with 40,121 additions and 2,755 deletions.
51 changes: 38 additions & 13 deletions codegen/annotationgen.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from cxxheaderparser.simple import parse_file, ParsedData, ParserOptions
from cxxheaderparser.preprocessor import make_gcc_preprocessor
from cxxheaderparser.types import Type, Pointer
from cxxheaderparser.types import Type, Pointer, Array


def main():
options = ParserOptions(preprocessor=make_gcc_preprocessor())
options = ParserOptions(preprocessor=make_gcc_preprocessor(defines=["CUBLASAPI="]))

nvml_ast: ParsedData = parse_file("/usr/include/nvml.h", options=options)
cudnn_graph_ast: ParsedData = parse_file("/usr/include/cudnn_graph.h", options=options)
cudnn_ops_ast: ParsedData = parse_file("/usr/include/cudnn_ops.h", options=options)
cuda_ast: ParsedData = parse_file("/usr/include/cuda.h", options=options)
cublas_ast: ParsedData = parse_file("/usr/include/cublas_api.h", options=options)
cudart_ast: ParsedData = parse_file(
"/usr/include/cuda_runtime_api.h", options=options
)
Expand All @@ -17,6 +20,9 @@ def main():
nvml_ast.namespace.functions
+ cuda_ast.namespace.functions
+ cudart_ast.namespace.functions
+ cudnn_graph_ast.namespace.functions
+ cudnn_ops_ast.namespace.functions
+ cublas_ast.namespace.functions
)

with open("annotations.h", "a") as f:
Expand All @@ -38,22 +44,41 @@ def main():
name=param.name, type=param.type.format()
)
)
elif isinstance(param.type, Array):
f.write(
" * @param {name} SEND_ONLY\n".format(
name=param.name, type=param.type.format()
)
)
f.write(" */\n")

params = []

for param in function.parameters:
if param.name and "[]" in param.type.format():
params.append(
"{type} {name}".format(
type=param.type.format().replace("[]", ""),
name=param.name + "[]",
)
)
elif param.name:
params.append(
"{type} {name}".format(
type=param.type.format(),
name=param.name,
)
)
else:
params.append(param.type.format())

joined_params = ", ".join(params)

f.write(
"{return_type} {name}({params});\n".format(
return_type=function.return_type.format(),
name=function.name.format(),
params=", ".join(
(
"{type} {name}".format(
type=param.type.format(),
name=param.name,
)
if param.name
else param.type.format()
)
for param in function.parameters
),
params=joined_params,
)
)

Expand Down
Loading

0 comments on commit 2bbc5d2

Please sign in to comment.