几个PopCount函数的实现

从 《Beautiful Code》这本书里面看到的一章《The Quest for an Accelerated Population Count》by Henrry S.Warren, Jr.他也是《Hacker Delight》的作者,那本书里面也收集了各种计算技巧,有时间可以拿来翻翻。这篇文章讲的就是如何计算一个整数中bit=1的数量。

UPDATE: 文章最后增加了性能对比,包括了 `__builtin_popcount` 的性能。


最简单的写法是循环32次,稍微好点的做法是提前判断是否为0,但是不知道branch predication的副作用有多大。如果值范围是可以固定的话,那么最好还是使用固定循环次数的写法,这样会更加有时间保证。

uint32_t popcount11(uint32_t x) {
    uint32_t ans = 0;
    while (x) {
        ans += x & 0x1;
        x = x >> 1;
    }
    return ans;
}

UPDATE: 其实可以换成 `x=x&(x-1)` 这样会更快,另外一个方式是使用表查询,效果好像比这个要更好。TABLE大小是 32 * 4 = 128 字节,占用两个cache line(64字节), 在内存访问效率上应该是可以的。

/*
data = []
for i in range(0,256,8):
    value = 0
    for j in reversed(range(8)):
        value = (value << 4) | popcount(i+j)
    data.append(value)
*/

uint32_t TABLE[] = {841031952,  1127363105, 1127363105, 1413694258, 1127363105,
                    1413694258, 1413694258, 1700025411, 1127363105, 1413694258,
                    1413694258, 1700025411, 1413694258, 1700025411, 1700025411,
                    1986356564, 1127363105, 1413694258, 1413694258, 1700025411,
                    1413694258, 1700025411, 1700025411, 1986356564, 1413694258,
                    1700025411, 1700025411, 1986356564, 1700025411, 1986356564,
                    1986356564, 2272687717};

inline uint32_t GET8(unsigned char x) {
    return (TABLE[x >> 3] >> ((x & 0x7) << 2)) & 0xf;
}

uint32_t popcount01(uint32_t x) {
    return GET8(x & 0xff) + GET8((x >> 8) & 0xff) + GET8((x >> 16) & 0xff) +
           GET8((x >> 24) & 0xff);
}

如果采用分治思想的话,那么可以写成下面这样的代码,好处是没有循环分支,并且指令数量也更少了。

uint32_t _popcount21(uint32_t x) {
    x = (x & 0x55555555) + ((x & 0xaaaaaaaa) >> 1);
    x = (x & 0x33333333) + ((x & 0xcccccccc) >> 2);
    x = (x & 0x0f0f0f0f) + ((x & 0xf0f0f0f0) >> 4);
    x = (x & 0x00ff00ff) + ((x & 0xff00ff00) >> 8);
    x = (x & 0x0000ffff) + ((x & 0xffff0000) >> 16);
    return x;
}

上面那个版本,其实和下面这个版本是等价的,好处是涉及到的常量少了,可能指令会更加精简。

uint32_t popcount21(uint32_t x) {
    x = (x & 0x55555555) + ((x >> 1) & 0x55555555);
    x = (x & 0x33333333) + ((x >> 2) & 0x33333333);
    x = (x & 0x0f0f0f0f) + ((x >> 4) & 0x0f0f0f0f);
    x = (x & 0x00ff00ff) + ((x >> 8) & 0x00ff00ff);
    x = (x & 0x0000ffff) + ((x >> 16) & 0x0000ffff);
    return x;
}

但是如果仔细观察的话,可以发现从 `x>>4` 这里开始,其实相加就已经不会出现溢出了。因为high bits最多有4个1, low bits最多有4个1, 相加起来最多8个1, 完全可以放在4个bits里面,只不过最后我们需要在取个低位。所以上面的代码可以简化为下面这样

uint32_t __popcount21(uint32_t x) {
    // 这里可以假设分别是0,1的情况
    // 如果是11的话,那么11-01 = 10 = 2
    // 10 - 01 = 01 = 1
    // 0x 这个就是 x
    x = x - ((x >> 1) & 0x55555555);
    x = (x & 0x33333333) + ((x & 0xcccccccc) >> 2);
    x = (x + (x >> 4)) & 0x0f0f0f0f;
    x = x + (x >> 8);
    x = x + (x >> 16);
    // 最后一次 low bits 最多 16, 就是 10000
    // high bits 最多 16,也是 10000
    // 所以最多就是 100000
    return x & 0x3f;
}

上面的思想可以扩展到两个数,以及4个数,只要在合适的机会下面将两个数直接相加就好。

uint32_t popcount22(uint32_t x, uint32_t y) {
    x = (x & 0x55555555) + ((x & 0xaaaaaaaa) >> 1);
    x = (x & 0x33333333) + ((x & 0xcccccccc) >> 2);

    y = (y & 0x55555555) + ((y & 0xaaaaaaaa) >> 1);
    y = (y & 0x33333333) + ((y & 0xcccccccc) >> 2);

    x += y;
    x = (x & 0x0f0f0f0f) + ((x & 0xf0f0f0f0) >> 4);
    x = (x & 0x00ff00ff) + ((x & 0xff00ff00) >> 8);
    x = (x & 0x0000ffff) + ((x & 0xffff0000) >> 16);
    return x;
}

uint32_t popcount24(uint32_t x, uint32_t y, uint32_t a, uint32_t b) {
    x = (x & 0x55555555) + ((x & 0xaaaaaaaa) >> 1);
    y = (y & 0x55555555) + ((y & 0xaaaaaaaa) >> 1);
    a = (a & 0x55555555) + ((a & 0xaaaaaaaa) >> 1);
    b = (b & 0x55555555) + ((b & 0xaaaaaaaa) >> 1);
    x = (x & 0x33333333) + ((x & 0xcccccccc) >> 2);
    y = (y & 0x33333333) + ((y & 0xcccccccc) >> 2);
    a = (a & 0x33333333) + ((a & 0xcccccccc) >> 2);
    b = (b & 0x33333333) + ((b & 0xcccccccc) >> 2);

    x += y;
    a += b;

    x = (x & 0x0f0f0f0f) + ((x & 0xf0f0f0f0) >> 4);
    a = (a & 0x0f0f0f0f) + ((a & 0xf0f0f0f0) >> 4);

    x += a;
    x = (x & 0x00ff00ff) + ((x & 0xff00ff00) >> 8);
    x = (x & 0x0000ffff) + ((x & 0xffff0000) >> 16);
    return x;
}

有了两个数的popcount求和,可以在上面做出扩展,比如求解 `pop(x) - pop(y)`, 这个式子可以变为 `pop(x) - (32 - pop(~y)) => pop(x) + pop(~y) - 32`

// pop(x) - pop(y) = pop(x) - (32 - pop(~y)) = pop(x) + pop(y) - 32
int popDiff(uint32_t x, uint32_t y) {
    x = x - ((x >> 1) & 0x55555555);
    x = (x & 0x33333333) + ((x >> 2) & 0x33333333);
    y = ~y;
    y = y - ((y >> 1) & 0x55555555);
    y = (y & 0x33333333) + ((y >> 2) & 0x33333333);

    x += y;
    x = (x + (x >> 4)) & 0x0f0f0f0f;
    x = (x + (x >> 8));
    x = (x + (x >> 16));
    return x & 0x0000007f - 32;
}

此外还有个高效实现来比较较两个数的popcount,首先使用bits进行抵消,然后不断地去clear lsb, 然后看谁先为0.

int popCompare(uint32_t xp, uint32_t yp) {
    unsigned x, y;
    x = xp & ~yp;
    y = yp & ~xp;
    while (1) {
        // if y == 0 then 0
        // else < 0
        if (x == 0) return y | -y;
        if (y == 0) return 1;
        x = x & (x - 1);  // clear lsb
        y = y & (y - 1);
    }
}

还有使用avx512 vpopcount dq指令的实现,因为我的CPU不支持,所以也没有运行,不知道实现是否正确以及效果如何。

// don't use it. I don't have any cpu support avx512 vpopcnt dq.
// https://gcc.gnu.org/onlinedocs/gcc/x86-Options.html
// g++ mm.cpp -g -W -Wall -mavx512f -mavx512vpopcntdq
uint32_t avx512_vpopcnt(const uint32_t* data, size_t size) {
    uint32_t ans = 0;
    uint64_t start = (uint64_t)data;
    if ((start % 64) != 0) {
        size_t rem = (start % 64) / 4;
        start = (start + 63) / 64 * 64;
        size -= rem;
        FORI(i, rem) ans += popcount21(data[i]);
    }

    const uint8_t* ptr = (uint8_t*)start;
    const uint8_t* end = ptr + size;
    const size_t chunks = size / 64;

    // count using AVX512 registers
    __m512i accumulator = _mm512_setzero_si512();
    for (size_t i = 0; i < chunks; i++, ptr += 64) {
        // Note: a short chain of dependencies, likely unrolling will be needed.
        const __m512i v = _mm512_loadu_si512((const __m512i*)ptr);
        const __m512i p = _mm512_popcnt_epi64(v);

        accumulator = _mm512_add_epi64(accumulator, p);
    }

    // horizontal sum of a register
    uint64_t tmp[8] __attribute__((aligned(64)));
    _mm512_store_si512((__m512i*)tmp, accumulator);

    for (size_t i = 0; i < 8; i++) {
        ans += (uint32_t)tmp[i];
    }

    // popcount the tail
    while (ptr + 4 < end) {
        ans += popcount21(*(uint32_t*)(ptr));
        ptr += 4;
    }
    return ans;
}

下面是性能数据,代码可以看这里 这里

可以看到分治实现比内置实现效率还要高点

[level-2] N = 1000, took: 82ms, avg 82ns/N, ans = 443894796
[level-1] N = 1000, took: 106ms, avg 106ns/N, ans = 443894796
[level0] N = 1000, took: 337ms, avg 337ns/N, ans = 443894796
[level1] N = 1000, took: 55ms, avg 55ns/N, ans = 443894796
[level2] N = 1000, took: 37ms, avg 37ns/N, ans = 443894796
[level4] N = 1000, took: 32ms, avg 32ns/N, ans = 443894796