Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Brodey | cublas + cudnn codegen #58

Merged
merged 23 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
350c7d8
chore: bm
brodeynewman Oct 9, 2024
1ad672a
Merge branch 'main' of github.com:kevmo314/scuda
brodeynewman Oct 9, 2024
59150ee
chore: merge
brodeynewman Oct 9, 2024
c7d0b7d
Merge branch 'main' of github.com:kevmo314/scuda
brodeynewman Oct 11, 2024
29a919e
Merge branch 'main' of github.com:kevmo314/scuda
brodeynewman Oct 14, 2024
233b8e9
Merge branch 'main' of github.com:kevmo314/scuda
brodeynewman Oct 14, 2024
fc00189
Merge branch 'main' of github.com:kevmo314/scuda
brodeynewman Oct 17, 2024
79ccd26
Merge branch 'main' of github.com:kevmo314/scuda
brodeynewman Oct 23, 2024
38a351c
Merge branch 'main' of github.com:kevmo314/scuda
brodeynewman Oct 29, 2024
ab2e209
Merge branch 'main' of github.com:kevmo314/scuda
brodeynewman Nov 6, 2024
ccd7c31
Merge branch 'main' of github.com:kevmo314/scuda
brodeynewman Nov 8, 2024
25cad41
Merge branch 'main' of github.com:kevmo314/scuda
brodeynewman Nov 9, 2024
11f8e43
Merge branch 'main' of github.com:kevmo314/scuda
brodeynewman Nov 11, 2024
8e3d836
Merge branch 'main' of github.com:kevmo314/scuda
brodeynewman Nov 18, 2024
e20c750
Merge branch 'main' of github.com:kevmo314/scuda
brodeynewman Nov 28, 2024
232a327
chore: cublas + cudnn codegen updates
brodeynewman Nov 29, 2024
0c4bee0
chore: array
brodeynewman Nov 30, 2024
5be9851
chore: batched function
brodeynewman Dec 1, 2024
8f56379
Merge branch 'main' of github.com:kevmo314/scuda
brodeynewman Dec 2, 2024
3cfacec
chore: rm migrate compute type for now
brodeynewman Dec 2, 2024
427ce8a
Merge branch 'main' into brodey/more-codegen
brodeynewman Dec 2, 2024
fb519a9
chore: cleanup
brodeynewman Dec 3, 2024
2fdc3be
fix: type
brodeynewman Dec 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 31 additions & 12 deletions codegen/annotationgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@


def main():
options = ParserOptions(preprocessor=make_gcc_preprocessor())
options = ParserOptions(preprocessor=make_gcc_preprocessor(defines=["CUBLASAPI="]))

nvml_ast: ParsedData = parse_file("/usr/include/nvml.h", options=options)
cudnn_graph_ast: ParsedData = parse_file("/usr/include/cudnn_graph.h", options=options)
cudnn_ops_ast: ParsedData = parse_file("/usr/include/cudnn_ops.h", options=options)
cuda_ast: ParsedData = parse_file("/usr/include/cuda.h", options=options)
cublas_ast: ParsedData = parse_file("/usr/include/cublas_api.h", options=options)
cudart_ast: ParsedData = parse_file(
"/usr/include/cuda_runtime_api.h", options=options
)
Expand All @@ -17,6 +20,9 @@ def main():
nvml_ast.namespace.functions
+ cuda_ast.namespace.functions
+ cudart_ast.namespace.functions
+ cudnn_graph_ast.namespace.functions
+ cudnn_ops_ast.namespace.functions
+ cublas_ast.namespace.functions
)

with open("annotations.h", "a") as f:
Expand All @@ -39,21 +45,34 @@ def main():
)
)
f.write(" */\n")

params = []

for param in function.parameters:
if param.name and "[]" in param.type.format():
params.append(
"{type} {name}".format(
type=param.type.format().replace("[]", ""),
name=param.name + "[]",
)
)
elif param.name:
params.append(
"{type} {name}".format(
type=param.type.format(),
name=param.name,
)
)
else:
params.append(param.type.format())

joined_params = ", ".join(params)

f.write(
"{return_type} {name}({params});\n".format(
return_type=function.return_type.format(),
name=function.name.format(),
params=", ".join(
(
"{type} {name}".format(
type=param.type.format(),
name=param.name,
)
if param.name
else param.type.format()
)
for param in function.parameters
),
params=joined_params,
)
)

Expand Down
Loading
Loading