Fast Inverse Square Root 快速平方根倒数

https://www.youtube.com/watch?v=p8u_k2LIZyo

https://en.wikipedia.org/wiki/Fast_inverse_square_root

IEEE 754 的float表示包括三个部分:

  1. Sign(1bit) 在这里肯定都是0
  2. Exponent(8bit),后面简写E
  3. Mantissa(23bit),后面简写M

那么浮点数表示是 (1 + M/(2^23)) * 2^(E-127).

视频里面解释的非常清楚,大致思想包括下面几个:

  1. 使用log2去将(E-127)部分分离出来
  2. log2(1+x) ~= (x + u), 其中u涉及到魔数的选择,那么log(1+M/(2^23)) ~= M/(2^23) + u
  3. 转换成为整数表示 (E + M/(2^23) ) << 23 = (E<<23) + M. 这就是float的整数表示
  4. 牛顿迭代计算f(x) = 0的话,迭代方法是 x = x0 - f(x0) / f'(x0). 这里f(x) = 1/(x^2) - number = 0
float Q_rsqrt( float number )
{
    long i;
    float x2, y;
    const float threehalfs = 1.5F;

    x2 = number * 0.5F;
    y  = number;
    i  = * ( long * ) &y;                       // evil floating point bit level hacking
    i  = 0x5f3759df - ( i >> 1 );               // what the fuck?
    y  = * ( float * ) &i;
    y  = y * ( threehalfs - ( x2 * y * y ) );   // 1st iteration
//    y  = y * ( threehalfs - ( x2 * y * y ) );   // 2nd iteration, this can be removed

    return y;
}

总之整来整去得到的就是下面这个式子(牛顿迭代法之前一步)

fast-inverse-sqrt-equation.png

这里u是log(1+x) - x的估算。如果假设u=0的话,那么计算出来的浮点数如下,M部分误差在6%左右

In [56]: a = (1<< 23) * 1.5  * (127)

In [57]: b = 0x5f3759df

In [60]: a, b, a-b, (a-b) / (1 << 23)
Out[60]: (1598029824.0, 1597463007, 566817.0, 0.06756985187530518)

我们可以做个试验看看不同u下面,M部分的误差是否会有所改善,以及平均下来u大约在什么范围

In [11]: def test():
    ...:     data = []
    ...:     import math
    ...:     for x2 in range(100):
    ...:         x = x2 * 0.01
    ...:         u = math.log2(1 + x) - x
    ...:         a = (1 << 23) * 1.5 * (127- u)
    ...:         b = 0x5f3759df
    ...:         r = abs(a - b) / (1 << 23)
    ...:         data.append((round(x, 4), round(u, 4), round(r, 4)))
    ...:     data.sort(key = lambda x: x[2])
    ...:     avgu = sum((x[1] for x in data)) / len(data)
    ...:     print('====top10=====')
    ...:     for x in data[:10]:
    ...:         print(x)
    ...:     value = 1.5 * (1 << 23) * (127 - avgu)
    ...:     number = hex(int(value))
    ...:     print('avg u = %.4f, number = %s' % (avgu, number))

结果如下,可以看到u大约取值是在0.0439 - 0.046区间内,平均u是0.0573, 对应的number就是 0x5f34ff97. 至于 0x5f3759df 这个魔数,对应的u是 0.0450466.

====top10=====
(0.81, 0.046, 0.0014)
(0.82, 0.0439, 0.0017)
(0.13, 0.0463, 0.0019)
(0.12, 0.0435, 0.0023)
(0.8, 0.048, 0.0044)
(0.83, 0.0418, 0.0048)
(0.14, 0.049, 0.006)
(0.11, 0.0406, 0.0067)
(0.79, 0.05, 0.0074)
(0.84, 0.0397, 0.008)
avg u = 0.0573, number = 0x5f34ff97

按照这个思路,我们也可以写个 sqrt(x) 的实现,只不过在牛顿迭代的时候,有除法计算,而且需要迭代个两次才能得到比较准确的结果,有点不太讲究。

float Q_sqrt( float number )
{
    long i;
    float y;

    // X = int(127-u) * (1 << 22)
    // u = 0.0573
    #define X 0x1fbc5532
    // u = 0.045
    // #define X 0x1fbd1df4
    y  = number;
    i  = * ( long * ) &y;
    i = (i >> 1) + X;
    // to avoid negative value.
    // i = i & 0x7fffffff;
    y  = * ( float * ) &i;
    y = 0.5f * (y + number / y);
    y = 0.5f * (y + number / y);
    return y;
}

int main() {
    for(int i=10;i<=1200;i+=20) {
        float x = Q_sqrt(i);
        cout << "i = " << i << ", x=sqrt(i) = " << x << ", x*x= " << x * x << endl;
    }
    return 0;
}