-
Notifications
You must be signed in to change notification settings - Fork 0
/
counter.cpp
81 lines (64 loc) · 3.06 KB
/
counter.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#include "defs.hpp"
#include <immintrin.h>
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <future>
#include <thread>
#include <vector>
// @powturbo's code with slight modifications
//
// Counts the bytes in [s, e) that compare equal to c, using AVX2.
//
// Every 32-byte vector is compared against a broadcast of c. A matching lane
// becomes 0xFF (-1), so subtracting the comparison result from an 8-bit
// accumulator adds one per match. The bulk loop works on blocks of 252 vectors
// spread over four accumulators (63 additions per lane, well below the 255 an
// 8-bit lane can hold) and then widens them into 64-bit partial sums with a
// sum-of-absolute-differences against zero.
//
// Loads are unaligned, so s may have any alignment. Callers such as
// opt_count_parallel hand in chunk pointers at arbitrary offsets.
// Precondition: e >= s.
inline uint64_t opt_count(const char *s, const char *e, const char c) {
    constexpr int kVectorBytes = 32;
    constexpr int kBlockBytes = 252 * kVectorBytes;

    const __m256i cv = _mm256_set1_epi8(c), zv = _mm256_setzero_si256();
    __m256i sum = zv, acr0, acr1, acr2, acr3;

    // End of the part of the input that consists of whole blocks; computed once
    // instead of re-evaluating the modulo on every iteration.
    const char *const blocks_end = s + (e - s) / kBlockBytes * kBlockBytes;
    while (s != blocks_end) {
        acr0 = acr1 = acr2 = acr3 = zv;
        for (const char *const pe = s + kBlockBytes; s != pe; s += 4 * kVectorBytes) {
            acr0 = _mm256_sub_epi8(acr0, _mm256_cmpeq_epi8(cv, _mm256_loadu_si256(reinterpret_cast<const __m256i *>(s))));
            acr1 = _mm256_sub_epi8(acr1, _mm256_cmpeq_epi8(cv, _mm256_loadu_si256(reinterpret_cast<const __m256i *>(s + 32))));
            acr2 = _mm256_sub_epi8(acr2, _mm256_cmpeq_epi8(cv, _mm256_loadu_si256(reinterpret_cast<const __m256i *>(s + 64))));
            acr3 = _mm256_sub_epi8(acr3, _mm256_cmpeq_epi8(cv, _mm256_loadu_si256(reinterpret_cast<const __m256i *>(s + 96))));
            _mm_prefetch(s + 1024, _MM_HINT_T0);
        }
        sum = _mm256_add_epi64(sum, _mm256_sad_epu8(acr0, zv));
        sum = _mm256_add_epi64(sum, _mm256_sad_epu8(acr1, zv));
        sum = _mm256_add_epi64(sum, _mm256_sad_epu8(acr2, zv));
        sum = _mm256_add_epi64(sum, _mm256_sad_epu8(acr3, zv));
    }

    // Remaining whole vectors: fewer than 252 of them, so a single 8-bit
    // accumulator cannot overflow. The bound is written as a distance so that
    // no pointer beyond e is ever formed.
    for (acr0 = zv; e - s >= kVectorBytes; s += kVectorBytes)
        acr0 = _mm256_sub_epi8(acr0, _mm256_cmpeq_epi8(cv, _mm256_loadu_si256(reinterpret_cast<const __m256i *>(s))));
    sum = _mm256_add_epi64(sum, _mm256_sad_epu8(acr0, zv));

    // Horizontal add of the four 64-bit partial sums.
    uint64_t count =
        static_cast<uint64_t>(_mm256_extract_epi64(sum, 0))
        + static_cast<uint64_t>(_mm256_extract_epi64(sum, 1))
        + static_cast<uint64_t>(_mm256_extract_epi64(sum, 2))
        + static_cast<uint64_t>(_mm256_extract_epi64(sum, 3));

    // Scalar tail for the last (< 32) bytes.
    // Using != is unsafe, use a stricter check
    while (s < e)
        count += *s++ == c;
    return count;
}
// Counts the bytes in [begin, end) equal to target.
//
// With singleThreaded set, or when the input is small relative to the number
// of hardware threads, the work is done by a single opt_count call on the
// calling thread. Otherwise the range is split into contiguous chunks, each
// chunk is counted by its own std::async task, and the partial counts are
// summed.
//
// The function is noexcept: if std::async fails to start a task the resulting
// exception terminates the program.
uint64_t opt_count_parallel(const char *begin, const char *end, const char target, bool singleThreaded) noexcept {
    if (singleThreaded)
        return opt_count(begin, end, target);

    // hardware_concurrency() may return 0 when the value cannot be determined;
    // treat that as one thread so the chunk size below never divides by zero.
    const unsigned int num_threads = std::max(1u, std::thread::hardware_concurrency());
    const size_t total_length = end - begin;

    // FIXME: Don't multiply by 1000 when used with verifier
    if (num_threads == 1 || total_length < static_cast<size_t>(num_threads) * 1000)
        return opt_count(begin, end, target);

    // Round the per-task size up to a multiple of the 32-byte vector width so
    // every chunk starts at the same alignment (mod 32) as begin; opt_count
    // reads its input in 32-byte vectors.
    constexpr size_t kVectorBytes = 32;
    size_t chunk_size = (total_length + num_threads - 1) / num_threads;
    chunk_size = (chunk_size + kVectorBytes - 1) / kVectorBytes * kVectorBytes;

    std::vector<std::future<uint64_t>> futures;
    futures.reserve(num_threads);

    // Walk by offset and clamp the last chunk, so neither chunk_begin nor
    // chunk_end can ever point past end (rounding up may leave fewer chunks
    // than threads).
    for (size_t offset = 0; offset < total_length; offset += chunk_size) {
        const char *chunk_begin = begin + offset;
        const char *chunk_end = chunk_begin + std::min(chunk_size, total_length - offset);
        futures.push_back(std::async(std::launch::async, [chunk_begin, chunk_end, target] {
            return opt_count(chunk_begin, chunk_end, target);
        }));
    }

    uint64_t total_count = 0;
    for (auto &future : futures)
        total_count += future.get();
    return total_count;
}
// Overload taking the target as an int code.
// Only codes in the 7-bit ASCII range [0, 127] are searched for; any other
// value yields a count of 0 without looking at the input.
uint64_t opt_count_parallel(const char *begin, const char *end, const int target, bool singleThreaded) noexcept {
    const bool is_ascii = target >= 0 && target <= 127;
    if (!is_ascii)
        return 0;

    // Narrow to char and forward to the char overload.
    const char needle = static_cast<char>(target);
    return opt_count_parallel(begin, end, needle, singleThreaded);
}