Skip to content

Commit

Permalink
Working cublas batched example (#60)
Browse files Browse the repository at this point in the history
  • Loading branch information
brodeynewman authored Dec 6, 2024
1 parent 2bbc5d2 commit a539495
Show file tree
Hide file tree
Showing 9 changed files with 4,864 additions and 1,127 deletions.
502 changes: 240 additions & 262 deletions codegen/annotations.h

Large diffs are not rendered by default.

122 changes: 37 additions & 85 deletions codegen/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,10 +190,9 @@ def client_rpc_write(self, f):
# array length operations are handled differently than char
elif isinstance(self.ptr, Array):
f.write(
" rpc_write(0, {param_name}, sizeof({param_type}[{length}])) < 0 ||\n".format(
" rpc_write(0, &{param_name}, sizeof({param_type})) < 0 ||\n".format(
param_name=self.parameter.name,
param_type=self.ptr.format().replace("[]", ""),
length=self.length.name,
param_type=self.parameter.name,
)
)
else:
Expand All @@ -216,7 +215,7 @@ def server_declaration(self) -> str:
c = self.ptr.const
self.ptr.const = False
# const[] isn't a valid part of a variable declaration
s = f" {self.ptr.format().replace("const[]", "")}* {self.parameter.name} = new {self.ptr.format().replace("const[]", "")}[{self.length.name}];\n"
s = f" {self.ptr.format().replace("const[]", "")}* {self.parameter.name} = nullptr;\n"
self.ptr.const = c
else:
c = self.ptr.ptr_to.const
Expand All @@ -230,17 +229,16 @@ def server_rpc_read(self, f):
return
elif isinstance(self.length, int):
f.write(
" rpc_read(conn, {param_name}, {size}) < 0 ||\n".format(
" rpc_read(conn, &{param_name}, {size}) < 0 ||\n".format(
param_name=self.parameter.name,
size=self.length,
)
)
elif isinstance(self.ptr, Array):
f.write(
" rpc_read(conn, {param_name}, sizeof({param_type}[{length}])) < 0 ||\n".format(
" rpc_read(conn, &{param_name}, sizeof({param_type})) < 0 ||\n".format(
param_name=self.parameter.name,
param_type=self.ptr.format().replace("[]", ""),
length=self.length.name,
)
)
else:
Expand All @@ -256,12 +254,6 @@ def server_rpc_read(self, f):
)
)

def server_len_rpc_read(self, f):
f.write(" if (rpc_read(conn, &{length_param}, sizeof(int)) < 0)\n".format(
length_param=self.length.name,
))
f.write(" return -1;\n")

@property
def server_reference(self) -> str:
return self.parameter.name
Expand Down Expand Up @@ -403,12 +395,13 @@ class OpaqueTypeOperation:
def client_rpc_write(self, f):
if not self.send:
return
f.write(
" rpc_write(0, &{param_name}, sizeof({param_type})) < 0 ||\n".format(
param_name=self.parameter.name,
param_type=self.type_.format(),
else:
f.write(
" rpc_write(0, &{param_name}, sizeof({param_type})) < 0 ||\n".format(
param_name=self.parameter.name,
param_type=self.type_.format(),
)
)
)

@property
def server_declaration(self) -> str:
Expand All @@ -418,7 +411,8 @@ def server_declaration(self) -> str:
# but "const cudnnTensorDescriptor_t *xDesc" IS valid. This subtle change carries reprecussions.
elif "const " in self.type_.format() and not "void" in self.type_.format() and not "*" in self.type_.format():
return f" {self.type_.format().replace("const", "")} {self.parameter.name};\n"
else: return f" {self.type_.format()} {self.parameter.name};\n"
else:
return f" {self.type_.format()} {self.parameter.name};\n"

def server_rpc_read(self, f):
if not self.send:
Expand Down Expand Up @@ -703,7 +697,15 @@ def main():

functions_with_annotations: list[tuple[Function, Function, list[Operation]]] = []

dupes = {}

for function in functions:
# ensure duplicate functions can't be written
if dupes.get(function.name.format()):
continue

dupes[function.name.format()] = True

try:
annotation = next(
f for f in annotations.namespace.functions if f.name == function.name
Expand Down Expand Up @@ -915,14 +917,6 @@ def main():
for function, annotation, operations, disabled in functions_with_annotations:
if function.name.format() in MANUAL_IMPLEMENTATIONS or disabled: continue

batched = False

# not a fan of this, but the batched functions are pretty standard with the flow below.
# batched functions are cublas functions that send pointer arrays where batchCount describes...
# the number of pointers in the arrays. This is non-trivial to generate.
if "Batched" in function.name.format():
batched = True

# parse the annotation doxygen
f.write(
"int handle_{name}(void *conn)\n".format(
Expand All @@ -933,70 +927,28 @@ def main():

defers = []

if batched:
array_batches = []
non_array_batches = []

for operation in operations:
if isinstance(operation, NullTerminatedOperation):
if error := operation.server_rpc_read(f, len(defers)):
defers.append(error)
if isinstance(operation, ArrayOperation):
array_batches.append(operation)
if not isinstance(operation, ArrayOperation):
non_array_batches.append(operation)

# print our normal operations the same
for operation in operations:
if operation not in array_batches:
f.write(operation.server_declaration)

# do something with array batches
if len(array_batches) > 0 and hasattr(array_batches[0], "server_len_rpc_read"):
array_batches[0].server_len_rpc_read(f)

# pop here, because we already accounted for the batchCount integer
non_array_batches.pop(0)

for op in array_batches:
f.write(op.server_declaration)

f.write(" int request_id;\n")
if function.return_type.format() != "void":
f.write(" {return_type} scuda_intercept_result;\n".format(return_type=function.return_type.format()))
else:
f.write(" void* scuda_intercept_result;\n".format(return_type=function.return_type.format()))
for operation in operations:
f.write(operation.server_declaration)

f.write(" if (\n")
for operation in operations:
operation.server_rpc_read(f)
f.write(" false)\n")
f.write(" goto ERROR_{index};\n".format(index=len(defers)))
f.write(" int request_id;\n")

f.write("\n")
# we only generate return from non-void types
if function.return_type.format() != "void":
f.write(" {return_type} scuda_intercept_result;\n".format(return_type=function.return_type.format()))
else:
for operation in operations:
f.write(operation.server_declaration)
f.write(" void* scuda_intercept_result;\n".format(return_type=function.return_type.format()))

f.write(" int request_id;\n")

# we only generate return from non-void types
if function.return_type.format() != "void":
f.write(" {return_type} scuda_intercept_result;\n".format(return_type=function.return_type.format()))
f.write(" if (\n")
for operation in operations:
if isinstance(operation, NullTerminatedOperation):
if error := operation.server_rpc_read(f, len(defers)):
defers.append(error)
else:
f.write(" void* scuda_intercept_result;\n".format(return_type=function.return_type.format()))

f.write(" if (\n")
for operation in operations:
if isinstance(operation, NullTerminatedOperation):
if error := operation.server_rpc_read(f, len(defers)):
defers.append(error)
else:
operation.server_rpc_read(f)
f.write(" false)\n")
f.write(" goto ERROR_{index};\n".format(index=len(defers)))
operation.server_rpc_read(f)
f.write(" false)\n")
f.write(" goto ERROR_{index};\n".format(index=len(defers)))

f.write("\n")
f.write("\n")

f.write(
" request_id = rpc_end_request(conn);\n".format(
Expand Down
Loading

0 comments on commit a539495

Please sign in to comment.