diff --git a/docs/nnue.md b/docs/nnue.md index 27a56d09..ba02569a 100644 --- a/docs/nnue.md +++ b/docs/nnue.md @@ -93,6 +93,8 @@ What this document DOES NOT contain: * [m256_haddx4](#m256_haddx4) - [Linear layer with sparse input](#linear-layer-with-sparse-input) * [m256_process_chunk](#m256_process_chunk) + - [Linear layer with blocked sparse input](#linear-layer-with-blocked-sparse-input) + * [Helping the odds](#helping-the-odds) - [Linear layer with sparse input, alternative approach](#linear-layer-with-sparse-input-alternative-approach) * [m256_process_chunk_alternative](#m256_process_chunk_alternative) - [Linear layer with sparse input and blocked sparse output](#linear-layer-with-sparse-input-and-blocked-sparse-output) @@ -1633,6 +1635,127 @@ inline void m256_process_chunk_alternative( } ``` +#### Linear layer with blocked sparse input + +If you read the previous two sections you realize that this approach can get complicated. The complication stems from the fact that in the naive way we have to process every row separately. But what if we grouped the inputs into chunks that are easier to work with? At first, this may appear as a pessimisation, because such an approach would have to consider an input to be non-zero if any of the values in a group is non-zero. + +Anyway, let's consider a group of size 4. With basic probability math we can calculate that if we group by 4, and the chance of a single input to be zero is `x`, then the chance of all 4 inputs to be zero is `x^4`. For example, if `x = 0.9` then `x^4 ~= 0.65` - that's almost 4 times as many indices to process... BUT. There's also 4 times fewer indices now, because they are grouped! What about the amount of work required for each non-zero input? Normally it would be 4 times as much, because inputs are grouped by 4. However, two things align to help our cause. First, as we saw, the implementation is clunky for group size of 1, and requires either more memory for the weights or additional work to unpack them. Second, It prevents us from efficiently handling a small amount of outputs, because we're limited by SIMD register width. Stockfish, for example, uses only 16 outputs after the large layer, so processing multiple (4) inputs at a time is a natural optimization. + +So, overall, we have the following tradeoffs: + +1. 4 times fewer indices to calculate +2. 4 times more inputs to process per index +3. but cheaper, we can do simple and fast processing for each input chunk again (no weight unpacking, no int16 weights), especially with less outputs + +Combined, it gives a sizable speedup for larger networks. + +Let's see the rough code. + +```cpp +int lsb(std::uint32_t v) { + // returns the least significant set bit in v + // implementation detail + // can be implemented for example using compiler intrinsics + // https://www.chessprogramming.org/BitScan#Leading_Zero_Count +} + +// 4 outputs per input +constexpr int ChunkSize = 4; + +// We will be processing 4 inputs at a time, so to do it efficiently we need to permute the weights. +// Figuring out why this permutation is like this is left as an excercise to the reader. +int get_weight_index_scrambled(const LinearLayer& layer, int i) +{ + return + (i / ChunkSize) % (layer.num_inputs / ChunkSize) * layer.num_outputs * ChunkSize + + i / layer.num_inputs * ChunkSize + + i % ChunkSize; +} + +void load_weights( + const LinearLayer& layer, + const int8_t* data +) { + for (int i = 0; i < layer.num_outputs * layer.num_inputs; ++i) { + layer.weights[get_weight_index_scrambled(i)] = data[i]; + } +} + +int32_t* linear_sparse_input( + const LinearLayer& layer, + int32_t* output, + const int8_t* input +) { + static_assert(is_same_v, + "This approach requires weights to be 8 bit."); + + constexpr int register_width = 256 / 8; + constexpr int input_register_width = register_width; // uint8_t + constexpr int output_register_width = register_width / 4; // int32_t + assert(layer.num_inputs % input_register_width == 0); + + // We need to find out the indices of the input values that are non-zero. + // Remember that we group the inputs by 4, so comparisons now use epi32. + uint16_t nnz_input_indices[layer.num_inputs / ChunkSize]; + int num_nnz_input_indices = 0; + + for (int i = 0; i < layer.num_inputs; i += input_register_width) { + const __m256i input_chunk = _mm256_load_si256(input + i); + // Find out where the values are greater than 0 and set the corresponding bits in nnz + // Annoyingly, we have to use _ps, because _epi32 doesn't exist for this instruction. + // This does incur some performance penalty due to domain change. + _mm256_movemask_ps((__m256) + _mm256_cmpgt_epi32(input_chunk, _mm256_setzero_si256()) + ); + + // Extract the indices of the set bits in nnz + while (nnz) { + const int lsb_index = lsb(nnz); + nnz &= nnz - 1; // reset the least significant set bit in nnz + nnz_input_indices[num_nnz_input_indices++] = i + lsb_index; + } + } + + // This time we will hold all outputs in registers, since we don't expect many of them. + const int num_regs = layer.num_outputs / output_register_width; + __m256i acc[num_regs]; + + // Initialize the accumulators with biases. + const __m256i* biasvec = reinterpret_cast(layer.biases); + for (int k = 0; k < num_regs; ++k) + acc[k] = biasvec[k]; + + // We will be loading inputs 4 at a time. + const auto input32 = reinterpret_cast(input); + + // We process one chunk at a time, but it's possible to unroll with some potential gains. + for (int i = 0; i < num_nnz_input_indices; ++i) { + const int input_id = nnz_input_indices[i]; + // We load 4 inputs at a time. + const __m256i factor = _mm256_set1_epi32(input32[input_id]); + + // Find the corresponding weights. + const auto col = reinterpret_cast(&weights[input_id * ChunkSize * layer.num_outputs]); + + // See how simple this part got now?! + // Back to our old and trusted m256_add_dpbusd_epi32. Horizontal accumulation for free! + for (int k = 0; k < num_regs; ++k) + m256_add_dpbusd_epi32(acc[k], factor, col[k]); + } + + // Store the accumulators directly into the output + __m256i* outptr = reinterpret_cast<__m256i*>(output); + for (int k = 0; k < num_regs; ++k) + outptr[k] = acc[k]; + + return output + layer.num_outputs; +} +``` + +##### Helping the odds + +The math of `x^4` assumes uniform distribution of non-zero inputs. We, however, help it a little bit by reordering the weights such that values that are more likely to be non-zero are grouped together (say, at the beginning). This can be performed empirically. This is a minor (~2%), but essentially free, speedup! + #### Linear layer with sparse input and blocked sparse output Let's go one step further. For now all linear layers had dense outputs, but we can consider a layer where each input is connected only to a subset of outputs. We can consider the weights to be 0 where no connection is present. To make it possible to implement efficiently with vectorization in mind we have to zero out whole blocks of weights. A 16x128 Weight matrix with 2 non-zero 1x16 blocks per input may look like this for example: