Pytorchスタイルになって書きやすくなったFlaxの新API「NNX」の使用感の確認のため、ALE/Breakout(ブロック崩し)向けにDQNを実装しました。 Jaxとは? ①Numpyの使いやすさ ②柔軟な自動微分 ③マルチCPU/GPU/TPUでの分散並列コンピューティング Flax NNXとは? PyTorchスタイルになったFlaxの新しいAPI 余談: Flax NXXまでの経緯 なぜ強化学習にJax/Flaxを使うのか? 大規模言語モデル向け強化学習 フィジカルAI(AIロボティクス) DQN(Deep-Q-Network)の実装 DQNとは Jax(GPU版)/Flaxのイ…