SIMD prefix sum实现
最近实现parquet data page v2里面几个encoding https://parquet.apache.org/docs/file-format/data-pages/encodings/, 其中delta binary packed encoding需要实现prefix sum.
直觉上觉得这个东西应该是有SIMD的实现,但是不知道怎么写,后来找到了两个还不错的源:
- Prefix Sum with SIMD - Algorithmica
- https://www.adms-conf.org/2020-camera-ready/ADMS20_05.pdf (parallel prefix with simd)
两个link里面分别给出了avx2和avx512的实现。按照avx512的实现我仿照写了一个256bit的版本,以为这样写出来的avx2实现会更加简单高效,后来查指令表格才发现它实际上用到了avx512f+avx512vl =D.
几个实现我放到了 [[Enhancement] add parquet DELTA_BINARY_PACKED encoding benchmark by dirtysalt · Pull Request #58470 · StarRocks/starrocks](https://github.com/StarRocks/starrocks/pull/58470) 这个PR里面,我这里也留个代码备份。
浏览代码的时候发现gcc在x86架构上支持function multiversioning(函数多版本),可以针对不同的指令集分别给出实现,然后在运行时自动进行dispatch. 此外如果给函数加上类似 `__attribute__((target("avx512f,avx512vl")))` 这样的attribute,那么即使编译选项里面没有写 "-mavx512f -mavx512vl" 也是可以编译的。只不过需要把所有的intrinsic头文件都包含进来(例如直接包含 `<x86intrin.h>`),比如下面这样
// Only x86 support function multiversion. // https://gcc.gnu.org/wiki/FunctionMultiVersioning // TODO(GoHalo) Support aarch64 platform. #if defined(__GNUC__) && defined(__x86_64__) #include <x86intrin.h> #define MFV_IMPL(IMPL, ATTR) \ _Pragma("GCC diagnostic push") _Pragma("GCC diagnostic ignored \"-Wunused-function\"") \ ATTR static inline IMPL _Pragma("GCC diagnostic pop") #define MFV_SSE42(IMPL) MFV_IMPL(IMPL, __attribute__((target("sse4.2")))) #define MFV_AVX2(IMPL) MFV_IMPL(IMPL, __attribute__((target("avx2")))) #define MFV_AVX512(IMPL) MFV_IMPL(IMPL, __attribute__((target("avx512f,avx512bw")))) #define MFV_DEFAULT(IMPL) MFV_IMPL(IMPL, __attribute__((target("default")))) #else #define MFV_SSE42(IMPL) #define MFV_AVX2(IMPL) #define MFV_AVX512(IMPL) #define MFV_DEFAULT(IMPL) IMPL #endif
然后我对这几个实现做了下benchmark,做的比较简单,都是在固定大小(4096)个int32数组上做测试:
- avx512 505ns
- avx2 973ns
- avx2x (实际使用了avx512f+vl) 804ns
- native (+prefetch) 1671ns
我的直觉是如果数组比较小的话,可能差距没有这么大。如果是数组比较大的话,avx512>avx2x>avx2.因为那个avx2实现需要有2 passes, 可能cache locality不是特别好。
链接里面给出了解决方法,就是针对一个Block做2 passes, 感觉这样对cache会更好。
--------------------------------------------------------------------------- Benchmark Time CPU Iterations --------------------------------------------------------------------------- BM_int32_avx512_prefix_sum/4096 505 ns 505 ns 1387154 BM_int32_avx512_prefix_sum/8192 1017 ns 1017 ns 688849 BM_int32_avx512_prefix_sum/16384 2336 ns 2335 ns 299731 BM_int32_avx512_prefix_sum/32768 4665 ns 4665 ns 149991 BM_int32_avx2_prefix_sum/4096 973 ns 973 ns 719452 BM_int32_avx2_prefix_sum/8192 1934 ns 1934 ns 361978 BM_int32_avx2_prefix_sum/16384 4204 ns 4204 ns 166237 BM_int32_avx2_prefix_sum/32768 8399 ns 8398 ns 83391 BM_int32_avx2x_prefix_sum/4096 804 ns 803 ns 871364 BM_int32_avx2x_prefix_sum/8192 1607 ns 1607 ns 435556 BM_int32_avx2x_prefix_sum/16384 3208 ns 3207 ns 218231 BM_int32_avx2x_prefix_sum/32768 6411 ns 6411 ns 109205 BM_int32_native_prefix_sum/4096 1671 ns 1670 ns 419114 BM_int32_native_prefix_sum/8192 3340 ns 3339 ns 209865 BM_int32_native_prefix_sum/16384 6684 ns 6683 ns 104963 BM_int32_native_prefix_sum/32768 13336 ns 13334 ns 52479 BM_int64_avx512_prefix_sum/4096 980 ns 980 ns 712131 BM_int64_avx512_prefix_sum/8192 1942 ns 1942 ns 360513 BM_int64_avx512_prefix_sum/16384 3882 ns 3881 ns 180216 BM_int64_avx512_prefix_sum/32768 7760 ns 7759 ns 90254 BM_int64_native_prefix_sum/4096 1671 ns 1671 ns 418874 BM_int64_native_prefix_sum/8192 3462 ns 3461 ns 202279 BM_int64_native_prefix_sum/16384 6910 ns 6909 ns 101323 BM_int64_native_prefix_sum/32768 13805 ns 13804 ns 50588
下面是avx2的实现代码
// =========================
// reference: https://en.algorithmica.org/hpc/algorithms/prefix/
// int32 / uint32_t version
//
// Decodes a delta chain in place: on return
//     buf[i] = last_value_in + sum_{k <= i} (buf_in[k] + min_delta)
// and last_value holds the final decoded element (it is left untouched when
// n <= 0). All additions wrap modulo 2^32, in the vector part and in the
// scalar tail alike.
//
//   buf        - n deltas on entry, n decoded values on return.
//   n          - number of elements in buf.
//   min_delta  - constant added to every delta before accumulation.
//   last_value - in: value preceding buf[0]; out: last decoded value.
__attribute__((target("avx2"))) static inline void delta_decode_chain_int32_avx2(int32_t* buf, int n, int32_t min_delta,
                                                                                 int32_t& last_value) {
    using v4i = __m128i;
    using v8i = __m256i;

    // Number of elements handled by the vector passes (a multiple of 8).
    const int sz = (n / 8) * 8;
    if (sz > 0) {
        // two passes, don't mix avx2 and sse2 inside one loop.
        // The intrinsics are called directly in this function body (rather than
        // from lambdas) so that they are covered by this function's target
        // attribute.

        // Pass 1 (avx2 instructions): add min_delta, then build a prefix sum
        // inside each group of 4 elements. _mm256_slli_si256 shifts bytes within
        // each 128-bit lane only, so nothing crosses a 4-element boundary here.
        const v8i v_min_delta = _mm256_set1_epi32(min_delta);
        for (int i = 0; i < sz; i += 8) {
            v8i* p = reinterpret_cast<v8i*>(buf + i);
            v8i x = _mm256_loadu_si256(p);
            x = _mm256_add_epi32(x, v_min_delta);
            x = _mm256_add_epi32(x, _mm256_slli_si256(x, 4));
            x = _mm256_add_epi32(x, _mm256_slli_si256(x, 8));
            _mm256_storeu_si256(p, x);
        }

        // Pass 2 (sse2 instructions): carry the running total across the
        // 4-element groups. `s` holds the previous decoded value in all lanes.
        v4i s = _mm_set1_epi32(last_value);
        for (int i = 0; i < sz; i += 4) {
            v4i* p = reinterpret_cast<v4i*>(buf + i);
            v4i x = _mm_loadu_si128(p);
            x = _mm_add_epi32(s, x);
            _mm_storeu_si128(p, x);
            // Broadcast the last element of this group as the next carry.
            s = _mm_shuffle_epi32(x, _MM_SHUFFLE(3, 3, 3, 3));
        }
        // any index is ok, all lanes of s are equal.
        last_value = _mm_extract_epi32(s, 0);
    }

    // Scalar tail for the remaining n - sz (< 8) elements. The sum is formed in
    // uint32_t so that overflow wraps like _mm_add_epi32 does, instead of being
    // signed-overflow undefined behaviour.
    for (int i = sz; i < n; i++) {
        const uint32_t v =
                static_cast<uint32_t>(buf[i]) + static_cast<uint32_t>(last_value) + static_cast<uint32_t>(min_delta);
        buf[i] = static_cast<int32_t>(v);
        last_value = buf[i];
    }
}
下面这个是avx512的实现代码
// Processes 16 int32 values at p in place and returns the new carry vector.
//
//   p           - 16 deltas on entry, 16 decoded values on return (unaligned ok).
//   s           - carry: the previously decoded value, replicated in all lanes.
//   v_min_delta - min_delta replicated in all lanes.
//   v_zero      - all-zero vector, used as the fill for the lane shifts.
//   v_perm15    - index vector with 15 in every lane.
//
// Returns the last decoded value (lane 15) replicated in all lanes.
__attribute__((target("avx512f"))) static inline __m512i prefix_and_accumulate_int32_avx512(
        int32_t* p, __m512i s, const __m512i& v_min_delta, const __m512i& v_zero, const __m512i& v_perm15) {
    // prefix
    __m512i x = _mm512_loadu_si512(p);
    x = _mm512_add_epi32(x, v_min_delta);
    // _mm512_alignr_epi32(x, v_zero, 16 - k) yields x moved up by k lanes with
    // zeros shifted in at the bottom; four doubling steps (k = 1, 2, 4, 8)
    // produce the inclusive prefix sum of all 16 lanes.
    x = _mm512_add_epi32(x, _mm512_alignr_epi32(x, v_zero, 16 - 1));
    x = _mm512_add_epi32(x, _mm512_alignr_epi32(x, v_zero, 16 - 2));
    x = _mm512_add_epi32(x, _mm512_alignr_epi32(x, v_zero, 16 - 4));
    x = _mm512_add_epi32(x, _mm512_alignr_epi32(x, v_zero, 16 - 8));
    // accumulate
    x = _mm512_add_epi32(s, x);
    _mm512_storeu_si512((__m512i*)p, x);
    // return last value.
    return _mm512_permutexvar_epi32(v_perm15, x);
}

// reference: https://www.adms-conf.org/2020-camera-ready/ADMS20_05.pdf
//
// Decodes a delta chain in place in a single pass: on return
//     buf[i] = last_value_in + sum_{k <= i} (buf_in[k] + min_delta)
// and last_value holds the final decoded element (it is left untouched when
// n <= 0). All additions wrap modulo 2^32, in the vector part and in the
// scalar tail alike.
//
//   buf        - n deltas on entry, n decoded values on return.
//   n          - number of elements in buf.
//   min_delta  - constant added to every delta before accumulation.
//   last_value - in: value preceding buf[0]; out: last decoded value.
__attribute__((target("avx512f"))) static inline void delta_decode_chain_int32_avx512(int32_t* buf, int n,
                                                                                      int32_t min_delta,
                                                                                      int32_t& last_value) {
    using v4i = __m128i;
    using v16i = __m512i;

    // Number of elements handled by the vector loop (a multiple of 16).
    const int sz = (n / 16) * 16;
    if (sz > 0) {
        // avx512 instructions.
        const v16i v_min_delta = _mm512_set1_epi32(min_delta);
        const v16i v_zero = _mm512_setzero_si512();
        const v16i v_perm15 = _mm512_set1_epi32(15);

        v16i s = _mm512_set1_epi32(last_value);
        for (int i = 0; i < sz; i += 16) {
            s = prefix_and_accumulate_int32_avx512(buf + i, s, v_min_delta, v_zero, v_perm15);
        }
        // All lanes of s are equal, so lane 0 of the low 128 bits is the carry.
        v4i s2 = _mm512_castsi512_si128(s);
        last_value = _mm_extract_epi32(s2, 0);
    }

    // Scalar tail for the remaining n - sz (< 16) elements. The sum is formed in
    // uint32_t so that overflow wraps like _mm512_add_epi32 does, instead of
    // being signed-overflow undefined behaviour.
    for (int i = sz; i < n; i++) {
        const uint32_t v =
                static_cast<uint32_t>(buf[i]) + static_cast<uint32_t>(last_value) + static_cast<uint32_t>(min_delta);
        buf[i] = static_cast<int32_t>(v);
        last_value = buf[i];
    }
}
然后这个是我以为可以运行在avx2上, 但是实际上使用到了avx512f+avx512vl的代码实现
// Though we handle 256bit as a unit, we still use some instructions of avx512f + avx512vl.
//
// Processes 8 int32 values at p in place and returns the new carry vector.
// Despite the `_avx2` suffix this helper is compiled for avx512f+avx512vl,
// because of _mm256_alignr_epi32.
//
//   p           - 8 deltas on entry, 8 decoded values on return (unaligned ok).
//   s           - carry: the previously decoded value, replicated in all lanes.
//   v_min_delta - min_delta replicated in all lanes.
//   v_zero      - all-zero vector, used as the fill for the lane shifts.
//   v_perm7     - index vector with 7 in every lane.
//
// Returns the last decoded value (lane 7) replicated in all lanes.
__attribute__((target("avx512f,avx512vl"))) static inline __m256i prefix_and_accumulate_int32_avx2(
        int32_t* p, __m256i s, const __m256i& v_min_delta, const __m256i& v_zero, const __m256i& v_perm7) {
    __m256i x = _mm256_loadu_si256((__m256i*)p);
    x = _mm256_add_epi32(x, v_min_delta);
    // _mm256_alignr_epi32(x, v_zero, 8 - k) yields x moved up by k lanes with
    // zeros shifted in at the bottom; three doubling steps (k = 1, 2, 4)
    // produce the inclusive prefix sum of all 8 lanes.
    x = _mm256_add_epi32(x, _mm256_alignr_epi32(x, v_zero, 8 - 1));
    x = _mm256_add_epi32(x, _mm256_alignr_epi32(x, v_zero, 8 - 2));
    x = _mm256_add_epi32(x, _mm256_alignr_epi32(x, v_zero, 8 - 4));
    // accumulate
    x = _mm256_add_epi32(s, x);
    _mm256_storeu_si256((__m256i*)p, x);
    // return last value.
    return _mm256_permutevar8x32_epi32(x, v_perm7);
}

// Decodes a delta chain in place in a single pass over 256-bit vectors: on return
//     buf[i] = last_value_in + sum_{k <= i} (buf_in[k] + min_delta)
// and last_value holds the final decoded element (it is left untouched when
// n <= 0). All additions wrap modulo 2^32, in the vector part and in the
// scalar tail alike.
//
//   buf        - n deltas on entry, n decoded values on return.
//   n          - number of elements in buf.
//   min_delta  - constant added to every delta before accumulation.
//   last_value - in: value preceding buf[0]; out: last decoded value.
__attribute__((target("avx2,avx512f,avx512vl"))) static inline void delta_decode_chain_int32_avx2x(
        int32_t* buf, int n, int32_t min_delta, int32_t& last_value) {
    using v8i = __m256i;

    // Number of elements handled by the vector loop (a multiple of 8).
    const int sz = (n / 8) * 8;
    if (sz > 0) {
        // avx2 instructions.
        const v8i v_min_delta = _mm256_set1_epi32(min_delta);
        const v8i v_zero = _mm256_setzero_si256();
        const v8i v_perm7 = _mm256_set1_epi32(7);

        v8i s = _mm256_set1_epi32(last_value);
        for (int i = 0; i < sz; i += 8) {
            s = prefix_and_accumulate_int32_avx2(buf + i, s, v_min_delta, v_zero, v_perm7);
        }
        // All lanes of s are equal, so any index works.
        last_value = _mm256_extract_epi32(s, 0);
    }

    // Scalar tail for the remaining n - sz (< 8) elements. The sum is formed in
    // uint32_t so that overflow wraps like _mm256_add_epi32 does, instead of
    // being signed-overflow undefined behaviour.
    for (int i = sz; i < n; i++) {
        const uint32_t v =
                static_cast<uint32_t>(buf[i]) + static_cast<uint32_t>(last_value) + static_cast<uint32_t>(min_delta);
        buf[i] = static_cast<int32_t>(v);
        last_value = buf[i];
    }
}