几个PopCount函数的实现
从 《Beautiful Code》这本书里面看到的一章《The Quest for an Accelerated Population Count》by Henrry S.Warren, Jr.他也是《Hacker Delight》的作者,那本书里面也收集了各种计算技巧,有时间可以拿来翻翻。这篇文章讲的就是如何计算一个整数中bit=1的数量。
UPDATE: 文章最后增加了性能对比,包括了 `__builtin_popcount` 的性能。
最简单的写法是循环32次,稍微好点的做法是提前判断是否为0,但是不知道branch predication的副作用有多大。如果值范围是可以固定的话,那么最好还是使用固定循环次数的写法,这样会更加有时间保证。
uint32_t popcount11(uint32_t x) { uint32_t ans = 0; while (x) { ans += x & 0x1; x = x >> 1; } return ans; }
UPDATE: 其实可以换成 `x=x&(x-1)` 这样会更快,另外一个方式是使用表查询,效果好像比这个要更好。TABLE大小是 32 * 4 = 128 字节,占用两个cache line(64字节), 在内存访问效率上应该是可以的。
/* data = [] for i in range(0,256,8): value = 0 for j in reversed(range(8)): value = (value << 4) | popcount(i+j) data.append(value) */ uint32_t TABLE[] = {841031952, 1127363105, 1127363105, 1413694258, 1127363105, 1413694258, 1413694258, 1700025411, 1127363105, 1413694258, 1413694258, 1700025411, 1413694258, 1700025411, 1700025411, 1986356564, 1127363105, 1413694258, 1413694258, 1700025411, 1413694258, 1700025411, 1700025411, 1986356564, 1413694258, 1700025411, 1700025411, 1986356564, 1700025411, 1986356564, 1986356564, 2272687717}; inline uint32_t GET8(unsigned char x) { return (TABLE[x >> 3] >> ((x & 0x7) << 2)) & 0xf; } uint32_t popcount01(uint32_t x) { return GET8(x & 0xff) + GET8((x >> 8) & 0xff) + GET8((x >> 16) & 0xff) + GET8((x >> 24) & 0xff); }
如果采用分治思想的话,那么可以写成下面这样的代码,好处是没有循环分支,并且指令数量也更少了。
uint32_t _popcount21(uint32_t x) { x = (x & 0x55555555) + ((x & 0xaaaaaaaa) >> 1); x = (x & 0x33333333) + ((x & 0xcccccccc) >> 2); x = (x & 0x0f0f0f0f) + ((x & 0xf0f0f0f0) >> 4); x = (x & 0x00ff00ff) + ((x & 0xff00ff00) >> 8); x = (x & 0x0000ffff) + ((x & 0xffff0000) >> 16); return x; }
上面那个版本,其实和下面这个版本是等价的,好处是涉及到的常量少了,可能指令会更加精简。
uint32_t popcount21(uint32_t x) { x = (x & 0x55555555) + ((x >> 1) & 0x55555555); x = (x & 0x33333333) + ((x >> 2) & 0x33333333); x = (x & 0x0f0f0f0f) + ((x >> 4) & 0x0f0f0f0f); x = (x & 0x00ff00ff) + ((x >> 8) & 0x00ff00ff); x = (x & 0x0000ffff) + ((x >> 16) & 0x0000ffff); return x; }
但是如果仔细观察的话,可以发现从 `x>>4` 这里开始,其实相加就已经不会出现溢出了。因为high bits最多有4个1, low bits最多有4个1, 相加起来最多8个1, 完全可以放在4个bits里面,只不过最后我们需要在取个低位。所以上面的代码可以简化为下面这样
uint32_t __popcount21(uint32_t x) { // 这里可以假设分别是0,1的情况 // 如果是11的话,那么11-01 = 10 = 2 // 10 - 01 = 01 = 1 // 0x 这个就是 x x = x - ((x >> 1) & 0x55555555); x = (x & 0x33333333) + ((x & 0xcccccccc) >> 2); x = (x + (x >> 4)) & 0x0f0f0f0f; x = x + (x >> 8); x = x + (x >> 16); // 最后一次 low bits 最多 16, 就是 10000 // high bits 最多 16,也是 10000 // 所以最多就是 100000 return x & 0x3f; }
上面的思想可以扩展到两个数,以及4个数,只要在合适的机会下面将两个数直接相加就好。
uint32_t popcount22(uint32_t x, uint32_t y) { x = (x & 0x55555555) + ((x & 0xaaaaaaaa) >> 1); x = (x & 0x33333333) + ((x & 0xcccccccc) >> 2); y = (y & 0x55555555) + ((y & 0xaaaaaaaa) >> 1); y = (y & 0x33333333) + ((y & 0xcccccccc) >> 2); x += y; x = (x & 0x0f0f0f0f) + ((x & 0xf0f0f0f0) >> 4); x = (x & 0x00ff00ff) + ((x & 0xff00ff00) >> 8); x = (x & 0x0000ffff) + ((x & 0xffff0000) >> 16); return x; } uint32_t popcount24(uint32_t x, uint32_t y, uint32_t a, uint32_t b) { x = (x & 0x55555555) + ((x & 0xaaaaaaaa) >> 1); y = (y & 0x55555555) + ((y & 0xaaaaaaaa) >> 1); a = (a & 0x55555555) + ((a & 0xaaaaaaaa) >> 1); b = (b & 0x55555555) + ((b & 0xaaaaaaaa) >> 1); x = (x & 0x33333333) + ((x & 0xcccccccc) >> 2); y = (y & 0x33333333) + ((y & 0xcccccccc) >> 2); a = (a & 0x33333333) + ((a & 0xcccccccc) >> 2); b = (b & 0x33333333) + ((b & 0xcccccccc) >> 2); x += y; a += b; x = (x & 0x0f0f0f0f) + ((x & 0xf0f0f0f0) >> 4); a = (a & 0x0f0f0f0f) + ((a & 0xf0f0f0f0) >> 4); x += a; x = (x & 0x00ff00ff) + ((x & 0xff00ff00) >> 8); x = (x & 0x0000ffff) + ((x & 0xffff0000) >> 16); return x; }
有了两个数的popcount求和,可以在上面做出扩展,比如求解 `pop(x) - pop(y)`, 这个式子可以变为 `pop(x) - (32 - pop(~y)) => pop(x) + pop(~y) - 32`
// pop(x) - pop(y) = pop(x) - (32 - pop(~y)) = pop(x) + pop(y) - 32 int popDiff(uint32_t x, uint32_t y) { x = x - ((x >> 1) & 0x55555555); x = (x & 0x33333333) + ((x >> 2) & 0x33333333); y = ~y; y = y - ((y >> 1) & 0x55555555); y = (y & 0x33333333) + ((y >> 2) & 0x33333333); x += y; x = (x + (x >> 4)) & 0x0f0f0f0f; x = (x + (x >> 8)); x = (x + (x >> 16)); return x & 0x0000007f - 32; }
此外还有个高效实现来比较较两个数的popcount,首先使用bits进行抵消,然后不断地去clear lsb, 然后看谁先为0.
int popCompare(uint32_t xp, uint32_t yp) { unsigned x, y; x = xp & ~yp; y = yp & ~xp; while (1) { // if y == 0 then 0 // else < 0 if (x == 0) return y | -y; if (y == 0) return 1; x = x & (x - 1); // clear lsb y = y & (y - 1); } }
还有使用avx512 vpopcount dq指令的实现,因为我的CPU不支持,所以也没有运行,不知道实现是否正确以及效果如何。
// don't use it. I don't have any cpu support avx512 vpopcnt dq. // https://gcc.gnu.org/onlinedocs/gcc/x86-Options.html // g++ mm.cpp -g -W -Wall -mavx512f -mavx512vpopcntdq uint32_t avx512_vpopcnt(const uint32_t* data, size_t size) { uint32_t ans = 0; uint64_t start = (uint64_t)data; if ((start % 64) != 0) { size_t rem = (start % 64) / 4; start = (start + 63) / 64 * 64; size -= rem; FORI(i, rem) ans += popcount21(data[i]); } const uint8_t* ptr = (uint8_t*)start; const uint8_t* end = ptr + size; const size_t chunks = size / 64; // count using AVX512 registers __m512i accumulator = _mm512_setzero_si512(); for (size_t i = 0; i < chunks; i++, ptr += 64) { // Note: a short chain of dependencies, likely unrolling will be needed. const __m512i v = _mm512_loadu_si512((const __m512i*)ptr); const __m512i p = _mm512_popcnt_epi64(v); accumulator = _mm512_add_epi64(accumulator, p); } // horizontal sum of a register uint64_t tmp[8] __attribute__((aligned(64))); _mm512_store_si512((__m512i*)tmp, accumulator); for (size_t i = 0; i < 8; i++) { ans += (uint32_t)tmp[i]; } // popcount the tail while (ptr + 4 < end) { ans += popcount21(*(uint32_t*)(ptr)); ptr += 4; } return ans; }
下面是性能数据,代码可以看这里 这里
- level-2: `__builtin_popcount` 实现
- level-1: 打表实现
- level0: 循环移位实现
- level1,2,4: 分治算法实现
可以看到分治实现比内置实现效率还要高点
[level-2] N = 1000, took: 82ms, avg 82ns/N, ans = 443894796 [level-1] N = 1000, took: 106ms, avg 106ns/N, ans = 443894796 [level0] N = 1000, took: 337ms, avg 337ns/N, ans = 443894796 [level1] N = 1000, took: 55ms, avg 55ns/N, ans = 443894796 [level2] N = 1000, took: 37ms, avg 37ns/N, ans = 443894796 [level4] N = 1000, took: 32ms, avg 32ns/N, ans = 443894796