FizzBuzz程序优化

这个程序比赛是从之前帖子看过来的,https://codegolf.stackexchange.com/questions/215216/high-throughput-fizz-buzz/236630


我测试了一下第一名ASM编写的程序在我们开发机器上的效果,大约只能跑到 6.6GiB/s. 代码在 这里. 帖子里面给出的分数是 60GiB/s, 感觉这个差距有点大。

帖子里面给的配置是 "Scores are from running on my desktop with an AMD 5950x CPU (16C / 32T). I have 32GB of 3600Mhz RAM."

我看了一下我们机器配置大约是

我也不确定这些配置是否会造成有这么大的差距??

现在还有个问题是,我的程序大部分时候运行很好,但是有时候却下降到500MB/s. 而第一名程序则没有这个问题,运行的是非常的稳定。这个需要在分析分析, 可能是当时有大量的内存写操作,但是好像从top上面看不到。或许后面需要使用perf看看系统在做什么事情。

UPDATE: 我自己感觉好像是kernel pipe buffer不够的原因:我使用 vmsplice 模式的话似乎不受影响,但是使用 write 模式的话影响就很大。


我的代码在 这里, 大约能跑到4.6GiB/s左右,也想不出怎么继续优化了。这里列举下优化点吧:

  1. 不要每次使用去使用itoa去计算整数的表示,这个可以通过模拟累加来完成
  2. 按照30步长进行循环展开:每15个是一轮FizzBuzz循环,而每10个一轮则是为了方便累加,所以取30.
  3. 对于短串memcpy长度尽可能地round到8个字节,不然会出现许多 "取最后2字节" “取最后一个字节"的操作
  4. 减少syscall调用次数,尽可能地写到buffer上面
  5. 使用 vmsplice 调用来减少 user/kernel 之间的数据拷贝,前提是需要设置好pipe size. (大约从3.2GB/s -> 4.2GB/s)

每次进行+10的操作

inline int add10(char* end) {
    char* p = end - 1;
    for (;;) {
        if (*p == 0) {
            *p = '1';
            return 1;
        } else if (*p == '9') {
            *p = '0';
            p -= 1;
        } else {
            *p = *p + 1;
            break;
        }
    }
    return 0;
}

按照30步长进行循环展开,有三种情况需要分别展开,下面这个是第一步展开的情况。其中end是 `xxxxx1\n` 这样字符串格式, 比如我们处理到了10000, 那么end就是 `10001\n`. 我们每次拷贝模板,然后修改最后一个字符。

但是这里也不只是每个都拷贝模板然后修改字符,有时候可以将最后一个字符和后面常量字符串一起写入,比如写入"2\nFizz\n"这个case.

这里还有个优化,是假设已经copy了end2次的话,那么可以使用这个duplicated进行拷贝。

#define MC(x, c) buf = op<c, c>(buf, x)
// 使用dup模式
template <int digit>
char* output0(char* RE buf, const char* RE end) {
    // 1   2    3    4     5     6    7   8   9     10
    // 1   2  fizz   4   buzz  fizz   7   8   fizz  buzz
    const char* dup = buf;
    MC(end, digit + 1);
    MC(end, digit - 1);
    MC("2\nFizz\n0", 7);

    MC(end, digit - 1);
    MC("4\nBuzz\nFizz\n0000", 12);

    MC(dup, 2 * digit);
    *(buf - 1 - digit) = '7';
    MC("8\nFizz\nBuzz\n0000", 12);
    return buf;
}

// 不使用dup模式

template <int digit>
char* output0(char* RE buf, const char* RE end) {
    // 1   2    3    4     5     6    7   8   9      10
    // 1   2  fizz   4   buzz  fizz   7   8   fizz   buzz
    MC(end, digit + 1);
    MC(end, digit - 1);
    MC("2\nFizz\n\0", 7);

    MC(end, digit - 1);
    MC("4\nBuzz\nFizz\n0000", 12);

    MC(end, digit + 1);
    *(buf - 2) = '7';
    MC(end, digit - 1);
    MC("8\nFizz\nBuzz\n0000", 12);
    return buf;
}

对mempcy进行优化,我们在所有的原始串上增加padding到8bytes上,虽然会多拷贝几个字节,但是指令数却可以减少(省去类似 `movzbwl/mov ah` 这样的指令)

template <int c, int c2>
char* op(char* RE buf, const char* RE p) {
    // 8 bytes as a unit.
    if constexpr (c <= 8) {
        memcpy(buf, p, 8);
    } else if constexpr (c <= 16) {
        memcpy(buf, p, 16);
    } else if constexpr (c <= 24) {
        memcpy(buf, p, 24);
    } else if constexpr (c <= 32) {
        memcpy(buf, p, 32);
    } else if constexpr (c <= 40) {
        memcpy(buf, p, 40);
    } else if constexpr (c <= 48) {
        memcpy(buf, p, 48);
    } else if constexpr (c <= 56) {
        memcpy(buf, p, 56);
    } else if constexpr (c <= 64) {
        memcpy(buf, p, 64);
    } else {
        static_assert(c <= 64);
    }
    buf += c2;
    return buf;
}

使用buffer来减少系统调用:因为我们最多处理到20位,然后每轮只处理30个字符,所以一轮最多产生600个字符(RESERVE). 可选地使用vmsplice或者是write来进行写入。

if ((buf - head) > (BUFFER_SIZE - RESERVE)) {
    size_t size = buf - head;
    if (use_vmsplice) {
        // ssize_t vmsplice(int fd, const struct iovec* iov, size_t nr_segs, unsigned int flags);
        iovec iov[1] = {
                {.iov_base = head, .iov_len = size},
        };
        vmsplice(1, iov, 1, 0);
        head = (head == buffer[0]) ? buffer[1] : buffer[0];
    } else {
        os_write(1, head, size);
    }
    buf = head;
}

使用vmsplice的话需要使用0/1 buffer, 并且设置pipe size,确保pipe size和单个buffer size相同。

bool fix_pipe_size() {
    int fd = 1;
    int pipe_size = fcntl(fd, F_GETPIPE_SZ);
    if (pipe_size == -1) {
        perror("get pipe size failed.");
        return false;
    }
    fprintf(stderr, "default pipe size: %d\n", pipe_size);

    int ret = fcntl(fd, F_SETPIPE_SZ, BUFFER_SIZE);
    if (ret < 0) {
        perror("set pipe size failed.");
        return false;
    }
    pipe_size = fcntl(fd, F_GETPIPE_SZ);
    if (pipe_size == -1) {
        perror("get pipe size failed.");
        return false;
    }
    fprintf(stderr, "new pipe size: %ld\n", pipe_size);
    return true;
}

int main() {
    // ...
    if (use_vmsplice) {
        bool ok = fix_pipe_size();
        if (!ok) {
            use_vmsplice = false;
            fprintf(stderr, "use_vmsplice disabled!\n");
        }
    }
}

UPDATE(20220813): 后面做了部分修改,我在的机器上带宽差不多是5GB/s左右。这个和机器环境很相关,在我同事的机器上可以翻倍甚至更多。

memcpy不是按照8字节对齐而是按照4字节对齐

template <int c>
char* op(char* RE buf, const char* RE p) {
    constexpr int x = (c + 3) / 4 * 4;
    memcpy(buf, p, x);
    buf += c;
    return buf;
}

在memcpy模式上访问顺序最好能保持一致,可能这样对于prefetch会比较友好

template <int digit>
char* output0(char* RE buf, const char* RE pp) {
    // 11   12   13    14    15        16   17   18   19  20    21
    // 1    fizz  3    4     fizzbuzz  6    7    fizz  9  Buzz Fizz
    MC(pp, digit);
    MC("1\nFizz\n000", 7);

    MC(pp, digit);
    MC("3\n00", 2);
    MC(pp, digit);
    MC("4\nFizzBuzz\n00000", 11);

    MC(pp, digit);
    MC("6\n00", 2);
    MC(pp, digit);
    MC("7\nFizz\n0", 7);

    MC(pp, digit);
    MC("9\nBuzz\nFizz\n0000", 12);
    return buf;
}

对digit管理数据结构包装在一个64字节对象以内,这样可以确保每次拿到digit buffer的话相关对象都可以拿到,减少L1 cache miss. 另外add10上可以略微做得更加紧凑一些,返回最新更新的指针,然后判断begin是否发生变化。

inline char* add10(char* end) {
    char* p = end;
    while (*p == '9') {
        *p = '0';
        p--;
    }
    *p = *p + 1;
    return p;
}

struct DigitContext {
    static constexpr int MAXDIGIT = 20;
    static constexpr int DIGITBUF = MAXDIGIT + 2;

    char digitbuf[DIGITBUF + 8];
    char* begin;
    char* end;
};
static_assert(sizeof(DigitContext) <= 64);
alignas(64) DigitContext digitctx[1];

其实帖子里面这个代码效率是非常高的,大致思想就是

#include <stdio.h>
#include <string.h>
#include <unistd.h>
char buf[416];
char out[65536 + 4096] = "1\n2\nFizz\n4\nBuzz\nFizz\n7\n8\nFizz\n";
int main(int argc, char **argv) {
  const int o[16] = { 4, 7, 2, 11, 2, 7, 12, 2, 12, 7, 2, 11, 2, 7, 12, 2 };
  char *t = out + 30;
  unsigned long long i = 1, j = 1;
  for (int l = 1; l < 20; l++) {
    int n = sprintf(buf, "Buzz\n%llu1\nFizz\n%llu3\n%llu4\nFizzBuzz\n%llu6\n%llu7\nFizz\n%llu9\nBuzz\nFizz\n%llu2\n%llu3\nFizz\nBuzz\n%llu6\nFizz\n%llu8\n%llu9\nFizzBuzz\n%llu1\n%llu2\nFizz\n%llu4\nBuzz\nFizz\n%llu7\n%llu8\nFizz\n", i, i, i, i, i, i, i + 1, i + 1, i + 1, i + 1, i + 1, i + 2, i + 2, i + 2, i + 2, i + 2);
    i *= 10;
    while (j < i) {
      memcpy(t, buf, n);
      t += n;
      if (t >= &out[65536]) {
        char *u = out;
        do {
          int w = write(1, u, &out[65536] - u);
          if (w > 0) u += w;
        } while (u < &out[65536]);
        memcpy(out, out + 65536, t - &out[65536]);
        t -= 65536;
      }
      char *q = buf;
      for (int k = 0; k < 16; k++) {
        char *p = q += o[k] + l;
        if (*p < '7') *p += 3;
        else {
          *p-- -= 7;
          while (*p == '9') *p-- = '0';
          ++*p;
        }
      }
      j += 3;
    }
  }
}

UPDATE(20220825): 又做了一些改进,现在在同事的机器上可以稳定在20GB/s上,而那个asm程序差不多是在40GB/s.

这次的优化思路是使用 代码生成工具,而不是使用模板。代码生成主要的目的是为了可以将要写的内容,通过计算的方式合并在一起,最后按照128bit/256bit写下去。

因为计算开销代价很小,然后每次都可以按照16bytes/32bytes写下去:我估算了一下,如果digit prefix在8个字节的时候,差不多需要4-8条指令(假设6)就可以填满然后写入,所以带宽可以达到 16bytes/6insts. 如果按照3Ghz来计算的话,CPI是0.5, 那么可以达到16GB/s.

代码生成方式和之前模板类似,模板是按照10个一组进行展开,而这个是按照100个一组展开。digit prefix最多可以有16位,所以对于数字有上限(<=10^18).

可以看看其中生成片段:

char* gen_output_2_8(char* RE buf, const char* RE pp) {
uint64_t e0=0,e1=0,e2=0,e3=0;
memcpy(&e0, pp + 0, 8);
__m128i PP = _mm_set_epi64x(e1, e0);
__m128i X = _mm_setzero_si128();
__m128i P, C;
C = _mm_set_epi64x(2682LL, 8820658356000290114LL); // Buzz\nFizz\n
X = C;
X = _mm_or_si128(X, _mm_bslli_si128(PP, 10));
_mm_storeu_si128((__m128i*)buf, X); /* X = _mm_setzero_si128(); */ buf += 16;
X = _mm_bsrli_si128(PP, 6);
C = _mm_set_epi64x(0LL, 668208LL); // 02\n
X = _mm_or_si128(X, _mm_bslli_si128(C, 2));
X = _mm_or_si128(X, _mm_bslli_si128(PP, 5));
C = _mm_set_epi64x(45004518722LL, 755050480103207728LL); // 03\nFizz\nBuzz\n
X = _mm_or_si128(X, _mm_bslli_si128(C, 13));
_mm_storeu_si128((__m128i*)buf, X); /* X = _mm_setzero_si128(); */ buf += 16;
C = _mm_bsrli_si128(C, 3);
X = C;
X = _mm_or_si128(X, _mm_bslli_si128(PP, 10));
_mm_storeu_si128((__m128i*)buf, X); /* X = _mm_setzero_si128(); */ buf += 16;

整个操作单位是128bit. 我也写了一个256bit的版本,其中最大的问题就是没有128bit上这样的bit shift操作。256bit上的bit shift操作是按照128bit lane来单独操作的, 也不是不能写,但是写出来会比较难看,需要4条指令。

    def mm256_merge(x, y, off, size):
        assert size <= 16
        if off == 0:
            return "%s = %s;" % (x, y)

        if off + size <= 16:
            return "%s = _mm256_or_si256(%s, _mm256_bslli_epi128(%s, %s));" % (x, x, y, off)

        # FIXME: not efficient.
        rshift = 16 - off
        if rshift > 0:
            inst = "__m256i t3 = _mm256_bsrli_epi128(t2, %d);" % (rshift);
        elif rshift < 0:
            inst = "__m256i t3 = _mm256_bslli_epi128(t2, %d);" % (-rshift);
        else:
            inst = "__m256i t3 = t2;"

        C = """{{ // mm256_merge({target}, {source}, {shift}, {size});
__m256i t = _mm256_bslli_epi128({source}, {shift});
__m256i t2 = _mm256_permute2f128_si256({source}, {source}, 0x08);
{inst}
{target} = _mm256_or_si256({target}, _mm256_or_si256(t, t3));
}}
""".format(target=x, source=y, shift=off, size=size, rshift=16 - off, inst=inst)
        return C

整个调试过程其实还挺麻烦的,但是好处是,上面代码只需要抽取片段就可以在单独的程序上调试,打印看看自己操作的结果是否正确。

指令还可以继续简化一下,但是其实差别不是太多了,最后输出的代码比如是这样的,整个过程中是不需要中间变量P和C的。

char* gen_output_2_9(char* RE buf, const char* RE pp) {
uint64_t e0=0,e1=0;
memcpy(&e0, pp + 0, 8);
memcpy(&e1, pp + 8, 1);
__m128i PP = _mm_set_epi64x(e1, e0);
__m128i X = _mm_setzero_si128();
__m128i P, C;
X = _mm_set_epi64x(2682LL, 8820658356000290114LL);// (Buzz\nFizz\n >> 0) << 0
X = _mm_or_si128(X, _mm_bslli_si128(PP, 10));
_mm_storeu_si128((__m128i*)buf, X); /* X = _mm_setzero_si128(); */ buf += 16;
X = _mm_bsrli_si128(PP, 6);
X = _mm_or_si128(X, _mm_set_epi64x(0LL, 11210669948928LL));// (02\n >> 0) << 3
X = _mm_or_si128(X, _mm_bslli_si128(PP, 6));
X = _mm_or_si128(X, _mm_set_epi64x(3458764513820540928LL, 0LL));// (03\nFizz\nBuzz\n >> 0) << 15
_mm_storeu_si128((__m128i*)buf, X); /* X = _mm_setzero_si128(); */ buf += 16;
X = _mm_set_epi64x(175798901LL, 4758750622441146931LL);// (03\nFizz\nBuzz\n >> 1) << 0
X = _mm_or_si128(X, _mm_bslli_si128(PP, 12));
_mm_storeu_si128((__m128i*)buf, X); /* X = _mm_setzero_si128(); */ buf += 16;