Change weight to channel-packing in Conv1d
Differential Revision: D66417572

Pull Request resolved: #7057
yipjustin authored Nov 26, 2024
1 parent a35cb73 commit 2967302
Showing 2 changed files with 18 additions and 16 deletions.
32 changes: 17 additions & 15 deletions backends/vulkan/runtime/graph/ops/glsl/conv1d.glsl
@@ -101,23 +101,25 @@ void main() {
   // "k" tracks the kernel's index for our input-kernel computation.
   // It reads out-of-bound zeros, but trying to avoid them complicates
   // for-loop conditions, which results in worse performance.
-  for (int k = 0; k < kernel_size; k += 4) {
-    // Since the weight tensor is width-packed, which is along the length
-    // dimension, we can batch-read four elements at a time.
-    const ivec3 w_lpos = ivec3(k / 4, in_c % in_group_size, out_c);
-    const VEC4_T weight = load_texel_lpos(kernel_in, w_lpos, kernel_axis_map);
-
-    ivec3 in_pos = lpos_to_pos(ivec3(in_l + k * dilation, in_c, n / 4), in_axis_map);
-    sum = fma(weight.xxxx, load_texel(t_in, in_pos), sum);
-
-    in_pos[in_axis_map.x] += dilation;
-    sum = fma(weight.yyyy, load_texel(t_in, in_pos), sum);
-
-    in_pos[in_axis_map.x] += dilation;
-    sum = fma(weight.zzzz, load_texel(t_in, in_pos), sum);
-
-    in_pos[in_axis_map.x] += dilation;
-    sum = fma(weight.wwww, load_texel(t_in, in_pos), sum);
+  // The weight tensor is channel-packed. This is not an obvious choice for
+  // performance, since it requires more texel fetches. However, in some
+  // sequence models we found that the weight tensor
+  // (out_channel, in_channel / group, kernel) often has out_channel >> kernel,
+  // which makes a width-packed texture very deep and wastes memory. As a
+  // mitigation, we channel-pack the weight tensor, yielding a 75% reduction
+  // in weight-tensor memory.
+
+  // It is possible to further reduce the memory footprint by swapping the
+  // dimensions, using the x extent for out_channel and y for kernel.
+  for (int k = 0; k < kernel_size; k += 1) {
+    const ivec3 w_lpos = ivec3(k, in_c % in_group_size, out_c / 4);
+    const VEC4_T weight_texel = load_texel_lpos(kernel_in, w_lpos, kernel_axis_map);
+    VEC4_T weight = VEC4_T(weight_texel[out_c % 4]);
+
+    ivec3 in_pos = lpos_to_pos(ivec3(in_l + k * dilation, in_c, n / 4), in_axis_map);
+    sum = fma(weight, load_texel(t_in, in_pos), sum);
   }
 }
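For readers tracing the new indexing, here is a minimal host-side sketch (hypothetical, not part of this commit; all sizes and names are illustrative) of how the channel-packed loop above locates a weight element: four consecutive output channels share one texel along the texture's logical depth axis, and out_c % 4 selects the component within the fetched texel.

#include <array>
#include <cstdio>

int main() {
  // Illustrative size; the real shader receives this via its layout bindings.
  const int in_group_size = 8;

  // An arbitrary weight element: kernel tap k, input channel, output channel.
  const int k = 2, in_c = 5, out_c = 10;

  // Channel-packed addressing, mirroring the shader's
  //   w_lpos = ivec3(k, in_c % in_group_size, out_c / 4)
  const std::array<int, 3> w_lpos = {k, in_c % in_group_size, out_c / 4};
  const int lane = out_c % 4;  // which x/y/z/w component of the texel

  std::printf("weight(out_c=%d, in_c=%d, k=%d) -> texel (%d, %d, %d), lane %d\n",
              out_c, in_c, k, w_lpos[0], w_lpos[1], w_lpos[2], lane);
  return 0;
}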

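The 75% figure in the comment can be sanity-checked from the depth extents alone. Below is a rough sketch, under the assumption (suggested by the comment about the texture getting very deep) that the weight tensor's footprint scales with its depth extent; the shape is illustrative, not taken from the commit.

#include <cstdio>

int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
  // Illustrative shape where out_channel >> kernel, as in the comment.
  const int out_channels = 1024;

  // Width-packed: kernel taps are packed 4-per-texel along x, so the depth
  // (z) extent is the full out_channels (the old code used z = out_c).
  // Channels-packed: output channels are packed 4-per-texel along z instead.
  const int z_width_packed = out_channels;
  const int z_channels_packed = ceil_div(out_channels, 4);

  // If memory tracks the depth extent, quartering z is a ~75% reduction.
  std::printf("depth extent: %d -> %d (%.0f%% reduction)\n", z_width_packed,
              z_channels_packed,
              100.0 * (1.0 - double(z_channels_packed) / z_width_packed));
  return 0;
}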
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
@@ -407,7 +407,7 @@ void add_conv1d_node(
     const ValueRef out,
     const bool clamp_out) {
   ValueRef arg_weight = prepack_standard(
-      graph, weight, graph.storage_type_of(out), utils::kWidthPacked);
+      graph, weight, graph.storage_type_of(out), utils::kChannelsPacked);
   ValueRef arg_bias = prepack_biases(
       graph,
       bias,
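This one-line host-side change ships together with the shader change above because the packing chosen at prepack time must match the indexing compiled into conv1d.glsl. A toy illustration (PackedDim and both constants are invented for this sketch; this is not the ExecuTorch API):

// Sketch only: these names are invented for illustration.
enum class PackedDim { Width, Channels };

constexpr PackedDim kShaderIndexing = PackedDim::Channels;  // conv1d.glsl
constexpr PackedDim kPrepackLayout = PackedDim::Channels;   // Convolution.cpp

// If one side changed without the other, the kernel would read the wrong
// weights; a build-time check of the invariant would look like:
static_assert(kShaderIndexing == kPrepackLayout,
              "weight packing must agree between prepack and shader");

int main() { return 0; }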
