[XLA:CPU] Fix the bug in transposed convolution async execution.
In the case of async execution, the intermediate buffer went out of scope before the callback was called, resulting in a read of already-released memory. Now ownership of the buffer is transferred to the lambda object, which extends the buffer's lifetime.
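The fix is the standard C++ pattern of moving ownership of a heap buffer into the callback that consumes it. Below is a minimal standalone sketch of that pattern (not XLA code; std::thread stands in for the Eigen thread pool, and the names merely mirror those in the diff):

    #include <iostream>
    #include <thread>
    #include <utility>
    #include <vector>

    int main() {
      std::vector<float> col_buffer(1024, 1.0f);  // intermediate buffer

      // Buggy version: a `[=]() mutable { ... }` capture keeps only what `=`
      // copies; if the buffer lives in the enclosing scope, the callback can
      // run after that scope has returned and read released memory. Moving
      // ownership into the lambda keeps the allocation alive until the
      // callback has finished.
      auto pack_patches = [col_buffer = std::move(col_buffer)]() {
        const float* col_buffer_data = col_buffer.data();  // lambda owns it
        std::cout << "first element: " << col_buffer_data[0] << "\n";
      };

      // Hand the work off to another thread; std::thread moves the callable,
      // so the buffer is never copied and travels with the lambda.
      std::thread worker(std::move(pack_patches));
      worker.join();
      return 0;
    }

Note that the lambda is moved, never copied, all the way to the executor: a copy taken before the contraction finished would own a different allocation than the one the contraction writes into, which is exactly the caveat the NOTE in the diff spells out.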

PiperOrigin-RevId: 700377235
Adam-Banas authored and tensorflower-gardener committed Nov 26, 2024
1 parent 1fbb217 commit d6d8f52
Showing 1 changed file with 13 additions and 4 deletions.
@@ -150,15 +150,24 @@ void EigenTransposedConv2D(
   const int output_offset = output_image_size * kernel_filters;
 
   // Pack the calculated patches into the output buffer.
-  auto pack_patches = [=]() mutable {
+  // NOTE: The ownership of col_buffer is transferred to the lambda without a
+  // data copy or reallocation. Thanks to that, the col_buffer_data pointer
+  // remains valid, which is important because the 'C' matrix references it.
+  // We need to make sure this lambda is never copied, otherwise col_buffer
+  // won't contain the contraction results at the time the lambda is called.
+  auto pack_patches = [=, col_buffer = std::move(col_buffer)]() {
+    // Using local pointers to the buffers, because the lambda is not mutable.
     const ScalarType* col_buffer_data = col_buffer.data();
+    ScalarType* local_out_data = out_data;
 
     // TODO(adambanas): Run this part in parallel.
     for (int image_id = 0; image_id < input_batch; ++image_id) {
       Pack1DPatches<ScalarType>(col_buffer_data, kernel_filters, output_y,
                                 kernel_y, padding_y_before, padding_y_after,
-                                lhs_y_dilation, out_data);
+                                lhs_y_dilation, local_out_data);
 
       col_buffer_data += input_offset;
-      out_data += output_offset;
+      local_out_data += output_offset;
     }
 
     // If done callback is provided, we need to call it after all the work is
@@ -170,7 +179,7 @@ void EigenTransposedConv2D(
 
   if (done_callback) {
     // Schedule the work in the thread pool and return.
-    C.device(device, pack_patches) = A.contract(B, contract_dims);
+    C.device(device, std::move(pack_patches)) = A.contract(B, contract_dims);
   } else {
     // Run synchronously in the current thread.
     C.device(device) = A.contract(B, contract_dims);
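For readers unfamiliar with this Eigen idiom: C.device(device, done) is the asynchronous form of tensor assignment, which evaluates the expression on the thread-pool device and invokes the callable once C is fully written. A rough standalone sketch of that API, assuming a recent Eigen with async-device (TensorAsyncDevice) support; tensor names and sizes are illustrative:

    #define EIGEN_USE_THREADS
    #include <unsupported/Eigen/CXX11/Tensor>

    #include <iostream>

    int main() {
      Eigen::ThreadPool pool(4);
      Eigen::ThreadPoolDevice device(&pool, /*num_cores=*/4);

      Eigen::Tensor<float, 2> A(64, 32), B(32, 48), C(64, 48);
      A.setRandom();
      B.setRandom();

      // Contract A's dimension 1 against B's dimension 0, i.e. a matmul.
      const Eigen::array<Eigen::IndexPair<int>, 1> contract_dims = {
          Eigen::IndexPair<int>(1, 0)};

      // The done callback fires on a pool thread after C has been written,
      // which is why everything it touches must outlive this scope.
      Eigen::Notification done;
      C.device(device, [&done]() { done.Notify(); }) =
          A.contract(B, contract_dims);
      done.Wait();

      std::cout << "C(0, 0) = " << C(0, 0) << "\n";
      return 0;
    }

In the commit, pack_patches plays the role of the done callback, so anything it needs (here, col_buffer) must stay alive until the contraction completes; that is precisely what the move capture guarantees.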
