Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Working cublas batched example #60

Merged
merged 20 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
304 changes: 141 additions & 163 deletions codegen/annotations.h

Large diffs are not rendered by default.

133 changes: 48 additions & 85 deletions codegen/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,10 +190,9 @@ def client_rpc_write(self, f):
# array length operations are handled differently than char
elif isinstance(self.ptr, Array):
f.write(
" rpc_write(0, {param_name}, sizeof({param_type}[{length}])) < 0 ||\n".format(
" rpc_write(0, &{param_name}, sizeof({param_type})) < 0 ||\n".format(
param_name=self.parameter.name,
param_type=self.ptr.format().replace("[]", ""),
length=self.length.name,
param_type=self.parameter.name,
)
)
else:
Expand All @@ -216,7 +215,7 @@ def server_declaration(self) -> str:
c = self.ptr.const
self.ptr.const = False
# const[] isn't a valid part of a variable declaration
s = f" {self.ptr.format().replace("const[]", "")}* {self.parameter.name} = new {self.ptr.format().replace("const[]", "")}[{self.length.name}];\n"
s = f" {self.ptr.format().replace("const[]", "")}* {self.parameter.name} = nullptr;\n"
self.ptr.const = c
else:
c = self.ptr.ptr_to.const
Expand All @@ -230,17 +229,16 @@ def server_rpc_read(self, f):
return
elif isinstance(self.length, int):
f.write(
" rpc_read(conn, {param_name}, {size}) < 0 ||\n".format(
" rpc_read(conn, &{param_name}, {size}) < 0 ||\n".format(
param_name=self.parameter.name,
size=self.length,
)
)
elif isinstance(self.ptr, Array):
f.write(
" rpc_read(conn, {param_name}, sizeof({param_type}[{length}])) < 0 ||\n".format(
" rpc_read(conn, &{param_name}, sizeof({param_type})) < 0 ||\n".format(
param_name=self.parameter.name,
param_type=self.ptr.format().replace("[]", ""),
length=self.length.name,
)
)
else:
Expand All @@ -256,12 +254,6 @@ def server_rpc_read(self, f):
)
)

def server_len_rpc_read(self, f):
f.write(" if (rpc_read(conn, &{length_param}, sizeof(int)) < 0)\n".format(
length_param=self.length.name,
))
f.write(" return -1;\n")

@property
def server_reference(self) -> str:
return self.parameter.name
Expand Down Expand Up @@ -403,12 +395,20 @@ class OpaqueTypeOperation:
def client_rpc_write(self, f):
if not self.send:
return
f.write(
" rpc_write(0, &{param_name}, sizeof({param_type})) < 0 ||\n".format(
param_name=self.parameter.name,
param_type=self.type_.format(),
elif "const double*" in self.type_.format():
brodeynewman marked this conversation as resolved.
Show resolved Hide resolved
f.write(
" rpc_write(0, {param_name}, sizeof({param_type})) < 0 ||\n".format(
param_name=self.parameter.name,
param_type=self.type_.format(),
)
)
else:
f.write(
" rpc_write(0, &{param_name}, sizeof({param_type})) < 0 ||\n".format(
param_name=self.parameter.name,
param_type=self.type_.format(),
)
)
)

@property
def server_declaration(self) -> str:
Expand All @@ -418,7 +418,10 @@ def server_declaration(self) -> str:
# but "const cudnnTensorDescriptor_t *xDesc" IS valid. This subtle change carries repercussions.
elif "const " in self.type_.format() and not "void" in self.type_.format() and not "*" in self.type_.format():
return f" {self.type_.format().replace("const", "")} {self.parameter.name};\n"
else: return f" {self.type_.format()} {self.parameter.name};\n"
elif "const double*" in self.type_.format():
return f" double {self.parameter.name};\n"
else:
return f" {self.type_.format()} {self.parameter.name};\n"

def server_rpc_read(self, f):
if not self.send:
Expand All @@ -434,6 +437,8 @@ def server_rpc_read(self, f):
def server_reference(self) -> str:
if self.recv:
return f"&{self.parameter.name}"
if "const double*" in self.type_.format():
return f"&{self.parameter.name}"
return self.parameter.name

def server_rpc_write(self, f):
Expand Down Expand Up @@ -703,7 +708,15 @@ def main():

functions_with_annotations: list[tuple[Function, Function, list[Operation]]] = []

dupes = {}

for function in functions:
# ensure duplicate functions can't be written
if dupes.get(function.name.format()):
continue

dupes[function.name.format()] = True

try:
annotation = next(
f for f in annotations.namespace.functions if f.name == function.name
Expand Down Expand Up @@ -915,14 +928,6 @@ def main():
for function, annotation, operations, disabled in functions_with_annotations:
if function.name.format() in MANUAL_IMPLEMENTATIONS or disabled: continue

batched = False

# not a fan of this, but the batched functions are pretty standard with the flow below.
# batched functions are cublas functions that send pointer arrays where batchCount describes...
# the number of pointers in the arrays. This is non-trivial to generate.
if "Batched" in function.name.format():
batched = True

# parse the annotation doxygen
f.write(
"int handle_{name}(void *conn)\n".format(
Expand All @@ -933,70 +938,28 @@ def main():

defers = []

if batched:
array_batches = []
non_array_batches = []

for operation in operations:
if isinstance(operation, NullTerminatedOperation):
if error := operation.server_rpc_read(f, len(defers)):
defers.append(error)
if isinstance(operation, ArrayOperation):
array_batches.append(operation)
if not isinstance(operation, ArrayOperation):
non_array_batches.append(operation)

# print our normal operations the same
for operation in operations:
if operation not in array_batches:
f.write(operation.server_declaration)

# do something with array batches
if len(array_batches) > 0 and hasattr(array_batches[0], "server_len_rpc_read"):
array_batches[0].server_len_rpc_read(f)

# pop here, because we already accounted for the batchCount integer
non_array_batches.pop(0)

for op in array_batches:
f.write(op.server_declaration)

f.write(" int request_id;\n")
if function.return_type.format() != "void":
f.write(" {return_type} scuda_intercept_result;\n".format(return_type=function.return_type.format()))
else:
f.write(" void* scuda_intercept_result;\n".format(return_type=function.return_type.format()))
for operation in operations:
f.write(operation.server_declaration)

f.write(" if (\n")
for operation in operations:
operation.server_rpc_read(f)
f.write(" false)\n")
f.write(" goto ERROR_{index};\n".format(index=len(defers)))
f.write(" int request_id;\n")

f.write("\n")
# we only generate return from non-void types
if function.return_type.format() != "void":
f.write(" {return_type} scuda_intercept_result;\n".format(return_type=function.return_type.format()))
else:
for operation in operations:
f.write(operation.server_declaration)

f.write(" int request_id;\n")
f.write(" void* scuda_intercept_result;\n".format(return_type=function.return_type.format()))

# we only generate return from non-void types
if function.return_type.format() != "void":
f.write(" {return_type} scuda_intercept_result;\n".format(return_type=function.return_type.format()))
f.write(" if (\n")
for operation in operations:
if isinstance(operation, NullTerminatedOperation):
if error := operation.server_rpc_read(f, len(defers)):
defers.append(error)
else:
f.write(" void* scuda_intercept_result;\n".format(return_type=function.return_type.format()))

f.write(" if (\n")
for operation in operations:
if isinstance(operation, NullTerminatedOperation):
if error := operation.server_rpc_read(f, len(defers)):
defers.append(error)
else:
operation.server_rpc_read(f)
f.write(" false)\n")
f.write(" goto ERROR_{index};\n".format(index=len(defers)))
operation.server_rpc_read(f)
f.write(" false)\n")
f.write(" goto ERROR_{index};\n".format(index=len(defers)))

f.write("\n")
f.write("\n")

f.write(
" request_id = rpc_end_request(conn);\n".format(
Expand Down
Loading
Loading