From effbb7b98f23c8ce724109bac9593b88ba2047b1 Mon Sep 17 00:00:00 2001 From: Christina Koutsou <74819775+kchristin22@users.noreply.github.com> Date: Sun, 3 Nov 2024 18:42:51 +0200 Subject: [PATCH] Add cudaMemset call after cudaMalloc for derivative pointers (#1129) --- lib/Differentiator/ReverseModeVisitor.cpp | 12 ++++++++++++ test/CUDA/GradientKernels.cu | 1 + 2 files changed, 13 insertions(+) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 11c0c1981..ee5b8e259 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1766,6 +1766,18 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc, llvm::MutableArrayRef(DerivedCallArgs), Loc) .get(); + if (FD->getNameAsString() == "cudaMalloc") { + if (auto* addrOp = dyn_cast<UnaryOperator>(DerivedCallArgs[0])) + if (addrOp->getOpcode() == UO_AddrOf) + DerivedCallArgs[0] = addrOp->getSubExpr(); // get the pointer + + llvm::SmallVector<Expr*, 3> args = {DerivedCallArgs[0], + getZeroInit(m_Context.IntTy), + DerivedCallArgs[1]}; + addToCurrentBlock(call_dx, direction::forward); + addToCurrentBlock(GetFunctionCall("cudaMemset", "", args)); + call_dx = nullptr; + } return StmtDiff(call, call_dx); } // For calls to C-style memory deallocation functions, we do not need to diff --git a/test/CUDA/GradientKernels.cu b/test/CUDA/GradientKernels.cu index 92f69d92c..fcbae3b5b 100644 --- a/test/CUDA/GradientKernels.cu +++ b/test/CUDA/GradientKernels.cu @@ -451,6 +451,7 @@ double fn_memory(double *out, double *in) { //CHECK-NEXT: double *_d_in_dev = nullptr; //CHECK-NEXT: double *in_dev = nullptr; //CHECK-NEXT: cudaMalloc(&_d_in_dev, 10 * sizeof(double)); +//CHECK-NEXT: cudaMemset(_d_in_dev, 0, 10 * sizeof(double)); //CHECK-NEXT: cudaMalloc(&in_dev, 10 * sizeof(double)); //CHECK-NEXT: cudaMemcpy(in_dev, in, 10 * sizeof(double), cudaMemcpyHostToDevice); //CHECK-NEXT: kernel_call<<<1, 
10>>>(out, in_dev);