Kahan Summation算法

Kahan Summation算法

Kahan Summation算法 #

在一些“积少成多”的计算过程中,比如在机器学习中,我们经常要计算海量样本计算出来的梯度或者 loss,于是会出现几亿个浮点数的相加。每个浮点数可能都差不多大,但是随着累积值的越来越大,就会出现“大数吃小数”的情况。

我们可以做一个简单的实验,用一个循环相加 2000 万个 1.0f,最终的结果会是 1600 万左右,而不是 2000 万。这是因为,加到 1600 万之后的加法因为精度丢失都没有了。这个代码比起上面的使用 2000 万来加 1.0 更具有现实意义。

public class FloatPrecision {
  public static void main(String[] args) {
    float sum = 0.0f;
    for (int i = 0; i < 20000000; i++) {
      float x = 1.0f;
      sum += x;
    }
    System.out.println("sum is " + sum);
  }
}

对应的输出结果是:

sum is 1.6777216E7

可以用Kahan Summation的算法来解决这个问题。

public class KahanSummation {
  public static void main(String[] args) {
    float sum = 0.0f;
    float c = 0.0f;
    for (int i = 0; i < 20000000; i++) {
      float x = 1.0f;
      float y = x - c;
      float t = sum + y;
      c = (t-sum)-y;
      sum = t;
    }
    System.out.println("sum is " + sum);
  }
}

对应的输出结果就是:

sum is 2.0E7

其实这个算法的原理其实并不复杂,就是在每次的计算过程中,都用一次减法,把当前加法计算中损失的精度记录下来,然后在后面的循环中,把这个精度损失放在要加的小数上,再做一次运算。

如果你对这个背后的数学原理特别感兴趣,可以去看一看Wikipedia 链接里面对应的数学证明,也可以生成一些数据试一试这个算法。这个方法在实际的数值计算中也是常用的,也是大量数据累加中,解决浮点数精度带来的“大数吃小数”问题的必备方案。

Viewpoint #

From #

16 | 浮点数和定点数(下):深入理解浮点数到底有什么用?