SIMD prefix sum实现

最近实现parquet data page v2里面几个encoding https://parquet.apache.org/docs/file-format/data-pages/encodings/, 其中delta binary packed encoding需要实现prefix sum.

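直觉上觉得这个东西应该是有SIMD的实现,但是不知道怎么写,后来找到了两个还不错的源:

- https://en.algorithmica.org/hpc/algorithms/prefix/
- https://www.adms-conf.org/2020-camera-ready/ADMS20_05.pdf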

两个link里面分别给出了avx2和avx512的实现。按照avx512的实现我仿照写了一个256bit的实现(即下文的avx2x),以为这样写avx2会更加简单高效,后来查指令表格发现用到了avx512f+avx512vl =D.

几个实现我放到了 [[Enhancement] add parquet DELTA_BINARY_PACKED encoding benchmark by dirtysalt · Pull Request #58470 · StarRocks/starrocks](https://github.com/StarRocks/starrocks/pull/58470) 这个PR里面,我这里也留个代码备份。

浏览代码的时候发现gcc在x86架构上支持function multiversioning(函数多版本),可以针对不同的指令集给同一个函数写不同的实现,然后自动进行runtime dispatch。此外如果使用类似 `__attribute__((target("avx512f,avx512vl")))` 这样的函数属性(attribute),那么即使编译选项里面不写 "-mavx512f -mavx512vl" 也是可以编译的。只不过需要把所有用到的intrinsic头文件都包含进来,比如下面这样

// Only x86 supports function multiversioning.
// https://gcc.gnu.org/wiki/FunctionMultiVersioning
// TODO(GoHalo) Support aarch64 platform.
//
// MFV_* helpers: each macro takes a complete function definition (IMPL) and,
// on GCC-compatible x86_64 builds, re-emits it as `static inline` with a
// specific `target` attribute attached.
#if defined(__GNUC__) && defined(__x86_64__)
// Pulls in the x86 intrinsic headers needed by the per-target versions.
#include <x86intrin.h>

// Expands to `ATTR static inline IMPL`, with -Wunused-function suppressed
// around the definition (the diagnostic state is pushed before and popped
// after it).
#define MFV_IMPL(IMPL, ATTR)                                                               \
    _Pragma("GCC diagnostic push") _Pragma("GCC diagnostic ignored \"-Wunused-function\"") \
            ATTR static inline IMPL _Pragma("GCC diagnostic pop")

// One wrapper per instruction-set level; MFV_DEFAULT marks the fallback
// version via target("default").
#define MFV_SSE42(IMPL) MFV_IMPL(IMPL, __attribute__((target("sse4.2"))))
#define MFV_AVX2(IMPL) MFV_IMPL(IMPL, __attribute__((target("avx2"))))
#define MFV_AVX512(IMPL) MFV_IMPL(IMPL, __attribute__((target("avx512f,avx512bw"))))
#define MFV_DEFAULT(IMPL) MFV_IMPL(IMPL, __attribute__((target("default"))))

#else

// Any other compiler/architecture: the ISA-specific macros expand to nothing,
// so those implementations are dropped, and MFV_DEFAULT emits IMPL exactly as
// written (no attribute, no added `static inline`).
#define MFV_SSE42(IMPL)
#define MFV_AVX2(IMPL)
#define MFV_AVX512(IMPL)
#define MFV_DEFAULT(IMPL) IMPL

#endif

然后我对这几个实现做了下benchmark,做的比较简单,都是在固定大小(4096)个int32数组上做测试:

我的直觉是如果数组比较小的话,可能差距没有这么大。如果是数组比较大的话,avx512>avx2x>avx2.因为那个avx2实现需要有2 passes, 可能cache locality不是特别好。

链接里面给出了解决方法,就是针对一个Block做2 passes, 感觉这样对cache会更好。

---------------------------------------------------------------------------
Benchmark                                 Time             CPU   Iterations
---------------------------------------------------------------------------
BM_int32_avx512_prefix_sum/4096         505 ns          505 ns      1387154
BM_int32_avx512_prefix_sum/8192        1017 ns         1017 ns       688849
BM_int32_avx512_prefix_sum/16384       2336 ns         2335 ns       299731
BM_int32_avx512_prefix_sum/32768       4665 ns         4665 ns       149991
BM_int32_avx2_prefix_sum/4096           973 ns          973 ns       719452
BM_int32_avx2_prefix_sum/8192          1934 ns         1934 ns       361978
BM_int32_avx2_prefix_sum/16384         4204 ns         4204 ns       166237
BM_int32_avx2_prefix_sum/32768         8399 ns         8398 ns        83391
BM_int32_avx2x_prefix_sum/4096          804 ns          803 ns       871364
BM_int32_avx2x_prefix_sum/8192         1607 ns         1607 ns       435556
BM_int32_avx2x_prefix_sum/16384        3208 ns         3207 ns       218231
BM_int32_avx2x_prefix_sum/32768        6411 ns         6411 ns       109205
BM_int32_native_prefix_sum/4096        1671 ns         1670 ns       419114
BM_int32_native_prefix_sum/8192        3340 ns         3339 ns       209865
BM_int32_native_prefix_sum/16384       6684 ns         6683 ns       104963
BM_int32_native_prefix_sum/32768      13336 ns        13334 ns        52479
BM_int64_avx512_prefix_sum/4096         980 ns          980 ns       712131
BM_int64_avx512_prefix_sum/8192        1942 ns         1942 ns       360513
BM_int64_avx512_prefix_sum/16384       3882 ns         3881 ns       180216
BM_int64_avx512_prefix_sum/32768       7760 ns         7759 ns        90254
BM_int64_native_prefix_sum/4096        1671 ns         1671 ns       418874
BM_int64_native_prefix_sum/8192        3462 ns         3461 ns       202279
BM_int64_native_prefix_sum/16384       6910 ns         6909 ns       101323
BM_int64_native_prefix_sum/32768      13805 ns        13804 ns        50588

下面是avx2的实现代码

// =========================
// reference: https://en.algorithmica.org/hpc/algorithms/prefix/
// int32 / uint32_t version
// Decodes `n` delta-encoded int32 values in place with a running (prefix) sum.
// On return buf[i] holds last_value + sum over k <= i of (buf[k] + min_delta),
// evaluated on the incoming contents of buf, and `last_value` holds the final
// decoded element (it is left unchanged when n <= 0).
//
//   buf        - deltas on input, decoded values on output
//   n          - number of elements in buf
//   min_delta  - constant added to every delta before accumulation
//   last_value - in: the value preceding buf[0]; out: the last decoded value
//
// All additions wrap modulo 2^32. The SIMD adds already behave that way; the
// scalar tail is carried out on uint32_t so that an overflowing sum is not
// signed-overflow undefined behavior and matches the vector path.
__attribute__((target("avx2"))) static inline void delta_decode_chain_int32_avx2(int32_t* buf, int n, int32_t min_delta,
                                                                                 int32_t& last_value) {
    // Largest multiple of 8 not exceeding n: the part handled with SIMD.
    const int sz = (n / 8) * 8;
    if (sz > 0) {
        // Two separate passes so that 256-bit and 128-bit code are not interleaved.
        // The loop bodies are written inline (not as lambdas) so every intrinsic
        // sits directly in this function, which carries the target attribute.

        // Pass 1 (256-bit): add min_delta, then build a local prefix sum inside
        // every group of four int32. _mm256_slli_si256 shifts bytes within each
        // 128-bit half, so nothing crosses a 4-element boundary here.
        const __m256i v_min_delta = _mm256_set1_epi32(min_delta);
        for (int i = 0; i < sz; i += 8) {
            __m256i* block = reinterpret_cast<__m256i*>(buf + i);
            __m256i x = _mm256_add_epi32(_mm256_loadu_si256(block), v_min_delta);
            x = _mm256_add_epi32(x, _mm256_slli_si256(x, 4));
            x = _mm256_add_epi32(x, _mm256_slli_si256(x, 8));
            _mm256_storeu_si256(block, x);
        }

        // Pass 2 (128-bit): add the running carry to each group of four, then
        // broadcast the group's last element as the carry for the next group.
        __m128i carry = _mm_set1_epi32(last_value);
        for (int i = 0; i < sz; i += 4) {
            __m128i* group = reinterpret_cast<__m128i*>(buf + i);
            const __m128i x = _mm_add_epi32(carry, _mm_loadu_si128(group));
            _mm_storeu_si128(group, x);
            carry = _mm_shuffle_epi32(x, _MM_SHUFFLE(3, 3, 3, 3));
        }
        // Every lane of carry holds the same value, so any index works.
        last_value = _mm_extract_epi32(carry, 0);
    }

    // Scalar tail for the remaining n - sz (< 8) elements, in unsigned
    // arithmetic so the sum wraps instead of overflowing a signed int.
    uint32_t acc = static_cast<uint32_t>(last_value);
    const uint32_t step = static_cast<uint32_t>(min_delta);
    for (int i = sz; i < n; i++) {
        acc += static_cast<uint32_t>(buf[i]) + step;
        buf[i] = static_cast<int32_t>(acc);
    }
    last_value = static_cast<int32_t>(acc);
}

下面这个是avx512的实现代码

// Decodes one block of 16 int32 deltas at `p` in place and returns the carry
// for the next block.
//
//   p           - 16 deltas on input, 16 decoded values on output (unaligned ok)
//   s           - carry added to every lane (the caller passes the previous
//                 decoded value broadcast to all lanes)
//   v_min_delta - min_delta broadcast to all lanes
//   v_zero      - all-zero vector, shifted in from below by the scan steps
//   v_perm15    - index vector with 15 in every lane
//
// Returns a vector whose 16 lanes all equal the last decoded value of the block.
__attribute__((target("avx512f"))) static inline __m512i prefix_and_accumulate_int32_avx512(int32_t* p, __m512i s,
                                                                                            const __m512i& v_min_delta,
                                                                                            const __m512i& v_zero,
                                                                                            const __m512i& v_perm15) {
    __m512i scan = _mm512_add_epi32(_mm512_loadu_si512(p), v_min_delta);

    // In-register inclusive scan in four doubling steps. alignr(scan, zero, 16 - k)
    // yields `scan` moved up by k lanes with zeros entering at the bottom.
    const __m512i up1 = _mm512_alignr_epi32(scan, v_zero, 16 - 1);
    scan = _mm512_add_epi32(scan, up1);
    const __m512i up2 = _mm512_alignr_epi32(scan, v_zero, 16 - 2);
    scan = _mm512_add_epi32(scan, up2);
    const __m512i up4 = _mm512_alignr_epi32(scan, v_zero, 16 - 4);
    scan = _mm512_add_epi32(scan, up4);
    const __m512i up8 = _mm512_alignr_epi32(scan, v_zero, 16 - 8);
    scan = _mm512_add_epi32(scan, up8);

    // Fold in the carry from the preceding block and write the block back.
    const __m512i decoded = _mm512_add_epi32(s, scan);
    _mm512_storeu_si512((__m512i*)p, decoded);

    // Broadcast lane 15 (the last decoded value) to every lane.
    return _mm512_permutexvar_epi32(v_perm15, decoded);
}

// reference: https://www.adms-conf.org/2020-camera-ready/ADMS20_05.pdf
// Decodes `n` delta-encoded int32 values in place with a running (prefix) sum,
// 16 elements per step. On return buf[i] holds
// last_value + sum over k <= i of (buf[k] + min_delta), evaluated on the
// incoming contents of buf, and `last_value` holds the final decoded element
// (it is left unchanged when n <= 0).
//
//   buf        - deltas on input, decoded values on output
//   n          - number of elements in buf
//   min_delta  - constant added to every delta before accumulation
//   last_value - in: the value preceding buf[0]; out: the last decoded value
//
// All additions wrap modulo 2^32. The SIMD adds already behave that way; the
// scalar tail is carried out on uint32_t so that an overflowing sum is not
// signed-overflow undefined behavior and matches the vector path.
__attribute__((target("avx512f"))) static inline void delta_decode_chain_int32_avx512(int32_t* buf, int n,
                                                                                      int32_t min_delta,
                                                                                      int32_t& last_value) {
    // Largest multiple of 16 not exceeding n: the part handled with SIMD.
    const int sz = (n / 16) * 16;
    if (sz > 0) {
        const __m512i v_min_delta = _mm512_set1_epi32(min_delta);
        const __m512i v_zero = _mm512_setzero_si512();
        const __m512i v_perm15 = _mm512_set1_epi32(15);

        // Single pass: each call scans one 16-element block, adds the carry
        // from the previous block and hands back the new carry in all lanes.
        __m512i s = _mm512_set1_epi32(last_value);
        for (int i = 0; i < sz; i += 16) {
            s = prefix_and_accumulate_int32_avx512(buf + i, s, v_min_delta, v_zero, v_perm15);
        }
        // Every lane of s holds the same value; read it from the low 128 bits.
        last_value = _mm_extract_epi32(_mm512_castsi512_si128(s), 0);
    }

    // Scalar tail for the remaining n - sz (< 16) elements, in unsigned
    // arithmetic so the sum wraps instead of overflowing a signed int.
    uint32_t acc = static_cast<uint32_t>(last_value);
    const uint32_t step = static_cast<uint32_t>(min_delta);
    for (int i = sz; i < n; i++) {
        acc += static_cast<uint32_t>(buf[i]) + step;
        buf[i] = static_cast<int32_t>(acc);
    }
    last_value = static_cast<int32_t>(acc);
}

然后这个是我以为可以运行在avx2上, 但是实际上使用到了avx512f+avx512vl的代码实现

// Though we handle 256bit as a unit, we still use some instructions of avx512f + avx512vl.
// Decodes one block of 8 int32 deltas at `p` in place and returns the carry for
// the next block. Despite the name, the lane-shift intrinsic used here is
// compiled under the avx512f + avx512vl target.
//
//   p           - 8 deltas on input, 8 decoded values on output (unaligned ok)
//   s           - carry added to every lane (the caller passes the previous
//                 decoded value broadcast to all lanes)
//   v_min_delta - min_delta broadcast to all lanes
//   v_zero      - all-zero vector, shifted in from below by the scan steps
//   v_perm7     - index vector with 7 in every lane
//
// Returns a vector whose 8 lanes all equal the last decoded value of the block.
__attribute__((target("avx512f,avx512vl"))) static inline __m256i prefix_and_accumulate_int32_avx2(
        int32_t* p, __m256i s, const __m256i& v_min_delta, const __m256i& v_zero, const __m256i& v_perm7) {
    __m256i* block = (__m256i*)p;
    __m256i scan = _mm256_add_epi32(_mm256_loadu_si256(block), v_min_delta);

    // In-register inclusive scan in three doubling steps. alignr(scan, zero, 8 - k)
    // yields `scan` moved up by k lanes with zeros entering at the bottom.
    const __m256i up1 = _mm256_alignr_epi32(scan, v_zero, 8 - 1);
    scan = _mm256_add_epi32(scan, up1);
    const __m256i up2 = _mm256_alignr_epi32(scan, v_zero, 8 - 2);
    scan = _mm256_add_epi32(scan, up2);
    const __m256i up4 = _mm256_alignr_epi32(scan, v_zero, 8 - 4);
    scan = _mm256_add_epi32(scan, up4);

    // Fold in the carry from the preceding block and write the block back.
    const __m256i decoded = _mm256_add_epi32(s, scan);
    _mm256_storeu_si256(block, decoded);

    // Broadcast lane 7 (the last decoded value) to every lane.
    return _mm256_permutevar8x32_epi32(decoded, v_perm7);
}

// Decodes `n` delta-encoded int32 values in place with a running (prefix) sum,
// 8 elements per step. On return buf[i] holds
// last_value + sum over k <= i of (buf[k] + min_delta), evaluated on the
// incoming contents of buf, and `last_value` holds the final decoded element
// (it is left unchanged when n <= 0).
//
//   buf        - deltas on input, decoded values on output
//   n          - number of elements in buf
//   min_delta  - constant added to every delta before accumulation
//   last_value - in: the value preceding buf[0]; out: the last decoded value
//
// All additions wrap modulo 2^32. The SIMD adds already behave that way; the
// scalar tail is carried out on uint32_t so that an overflowing sum is not
// signed-overflow undefined behavior and matches the vector path.
__attribute__((target("avx2,avx512f,avx512vl"))) static inline void delta_decode_chain_int32_avx2x(
        int32_t* buf, int n, int32_t min_delta, int32_t& last_value) {
    // Largest multiple of 8 not exceeding n: the part handled with SIMD.
    const int sz = (n / 8) * 8;
    if (sz > 0) {
        const __m256i v_min_delta = _mm256_set1_epi32(min_delta);
        const __m256i v_zero = _mm256_setzero_si256();
        const __m256i v_perm7 = _mm256_set1_epi32(7);

        // Single pass: each call scans one 8-element block, adds the carry
        // from the previous block and hands back the new carry in all lanes.
        __m256i s = _mm256_set1_epi32(last_value);
        for (int i = 0; i < sz; i += 8) {
            s = prefix_and_accumulate_int32_avx2(buf + i, s, v_min_delta, v_zero, v_perm7);
        }
        // Every lane of s holds the same value, so lane 0 is as good as any.
        last_value = _mm256_extract_epi32(s, 0);
    }

    // Scalar tail for the remaining n - sz (< 8) elements, in unsigned
    // arithmetic so the sum wraps instead of overflowing a signed int.
    uint32_t acc = static_cast<uint32_t>(last_value);
    const uint32_t step = static_cast<uint32_t>(min_delta);
    for (int i = sz; i < n; i++) {
        acc += static_cast<uint32_t>(buf[i]) + step;
        buf[i] = static_cast<int32_t>(acc);
    }
    last_value = static_cast<int32_t>(acc);
}