指数関数に関わる浮動小数点数のクイズ(解答編)

melancholic afternoon

昼休みに光成さんに「日記に書いたパズルが、解くのに必要な量のヒントが与えられているか、それとも足りないのかを(この手の問題をやりこんでいる)自分では判断できないので解いてみて欲しい」と言われたので解いてみました。結果から言うと、僕が90分考えれば解ける程度のレベルでした。(簡単という意味ではない)

次の関数はfmat.hppのexp()に使われているロジックをCで書いたものである. そこに現れる各定数は何を意味するのか, どういうロジックなのか答えよ(ただし -88 <= x <= 88としてよい -- そうじゃなければ0か∞なので).

unsigned int tbl[1024]; // 実行時に一度だけ初期化される

union fi {
    float f;
    unsigned int i;
};

float expC(float x)
{
    float t1 = x * 1.4773197e3 + 12582912; // (1)
    fi fi;
    fi.f = t1;
    float t2 = x - (t1 - 12582912) * 6.7690155e-4; // (2)
    unsigned int u = ((fi.i + 130048) >> 10) << 23; // (3)
    fi.i = u | tbl[fi.i & 1023]; // (4)
    return (1 + t2) * fi.f; // (5)
}

それではネタバレ防止のために適当な数列を挟んでから僕がどうやって考えたかの解説を:

52

26

13

40

20

10

5

16

8

4

2

1

まず、floatって何だ?という基礎を確認する。浮動小数点数 - Wikipedia。floatはIEEE754方式では1ビットの符号、8ビットの指数部, 23ビットの仮数分の仮数部からなる。



次に12582912とか130048がどういう値なのか2進法表現に変えてみる。もちろん1023は111111111ね。

In [12]: def to2(x):
   ....:     buf = []
   ....:     while x:
   ....:         buf.insert(0, x % 2)
   ....:         x /= 2
   ....:     return "".join(map(str, buf))
   ....:

In [13]: to2(10)
Out[13]: '1010'

In [14]: to2(130048)
Out[14]: '11111110000000000'

In [15]: to2(12582912)
Out[15]: '110000000000000000000000'

In [16]: len(_)
Out[16]: 24

あれ、24ビットもある。floatの仮数部の精度は23ビットしかないのに。


次に1.4773197e3とかがなんなのかを確認するためにこんなコードを書いてみる。

typedef union Float {
  float    f;
  unsigned i;
} Float;

int main(){
  Float a;
  a.f = 6.7690155e-4;
  printf("%d\n", a.i);
}

がしかし、ビット表現を見てもよくわからない。

0 10001001 01110001010101000111011 // a
0 10010110 10000000000000000000000 // b
0 01110100 01100010111001000011000 // c

んー。



しばらく数値だけに着目して色々試してみたけど成果が得られず、ソースコードを読むことにする。130048 は 11111110000000000 1が7個0が10個だった。つまり、

unsigned int u = ((fi.i + 130048) >> 10) << 23; // (3)
fi.i = u | tbl[fi.i & 1023]; // (4)

というコードは「下10ビットにテーブル引き用の値が入っていて、その上の7ビットには指数部に入れるべき値がオフセットなしで入っている」ということか。つまりこれが終わった時点でfiの値はテーブルから取ってきた値の指数部を書き換えたものになっているわけだから

    return (1 + t2) * fi.f;

というのはつまりテーブル引きした値から補正するための情報がt2に入っているということだな。


xと(t1 & 130048) >> 10の関係を調べる。

int main(){
  Float a;
  for(float x=0.0; x<10; x+=0.1){
    a.f = x * 1.4773197e3 + 12582912;    
    printf("%f, %d\n", x, (a.i & 130048) >> 10);
  }
}

結果:

0.000000, 0
...
0.600000, 0
0.700000, 1
...
1.300000, 1
1.400000, 2

ふむふむ。0.7ごとに1増えるのか。exp(0.7)はおよそ2だろうな。

In[64]: 1024 / log(2)
Out[64]: 1477.3197218702985

正解。1.4773197e3の意味はわかった。xにこの値を掛けると、x * 1024 / log(2)はxがlog(2)倍になると1024増える。xがlog(2)倍になるということはexp(x)が2倍になるということだ。なのでこの値の下10ビットを捨てて指数部オフセットの1111111を足してやればexp(x)の指数部として使えるというわけだ。

bが必要な理由もおのずと明らかだ。「下10ビットを捨てて」などとまるで整数のビット演算のように語っているけども、これは「浮動」小数点数なので勝手に小数点が移動してしまう。値が1だとしても、最下位ビットが1になってくれない。そこで12582912(これは3 << 22)を足して、浮動小数点数の一番下のビットが整数で言うところの1の位になるようにしている。1 << 23 ではなく 3 << 22 なのは今回 x * 1.4773197e3 が負の値を取りうるからだ。1 << 23 では負の値が来たときにくり下がりで24ビット目が0になってしまう。



最後に誤差の見積に関して。x * aはbを足したことで小数点以下が丸められてしまっている。そこでbを引いてaの逆数を掛けることで元の形に戻してやって、元々のxから引くことで「テーブル引きのために丸めたことでどれくらいxがずれたか(Δx)」を求める。これがt2。そしてテーブル引きした値に (1 + t2)を掛けている。これはテーブル引きした値に x * Δx を足しているわけだ。

    return (1 + t2) * fi.f;

exp(x)の導関数はxなのでこれはx軸方向のズレに傾きを掛けて一次近似で求めていることになる。ってことは誤差としてはxの2乗以降の項が残ることになる。Δxは1/2048未満なのでその2乗の誤差が入っても22ビットくらいの精度がある。exp(x)のマクローリン展開の2次の項はx^2/2だから23ビットあるんじゃないかという気もする。