前回はParallel Scanのforwardについてある程度確かめた。今回はParallel Scanの逆伝播と状態空間モデルの離散化について確認する。 Parallel Scanの逆伝播 Mambaの論文ではメモリ使用量を抑えるために、値を保持しておくのではなく逆伝播時に再計算をすると書かれている。mamba.pyでもそのようにされているようで、かつParallel Scanの逆伝播はpscan_revという新たな関数を用いて実装されている。これは「flip the input, call pscan, then flip the output」を行う操作とのことである。これで上手く計算…