VLA 推論の計算グラフ最適化

pi0 級の VLA を単一コンシューマ GPU(RTX 4090)で 30 FPS(27.3 ms / 2 view) で走らせるための、推論パイプラインのエンジニアリング手法群。realtime-vla の中核技術で、naive PyTorch 実装 (105 ms) → CUDA Graph → グラフ簡略化 → カーネル最適化と段階的に縮める。

1. CPU オーバーヘッドの除去(CUDA Graph)

pi0 の1推論ステップで起動されるカーネルは1000以上。Python → CUDA ドライバのディスパッチが律速になる。Transformer ブロックには動的分岐がなくバッファポインタも固定できるため、CUDA Graph で1度ストリームを記録し以後 replay() するだけにする。GPU とドライバだけでカーネルが連鎖起動され、Python 実行コストが消える。これだけで約2倍高速化。

LeRobot 実装(HuggingFace Transformers ベース)では、AE 1層あたり約21カーネル × 18層 × 10 denoise step = 約3,780回の起動 + copy.deepcopy(past_key_values) の毎ステップ複製が支配的だった。Realtime-VLA は融合で ~5 kernels/層 = ~900回まで削減。

2. 計算グラフの等価変換(簡略化)

コンパイラの定数畳み込みに相当する書き換えで MAC とカーネル数を削減(7–8 ms 改善):

  • RMSNorm affine 吸収: RMSNorm の学習スケール γ を後続線形層の重みに事前乗算。両者とも線形なので結合則で融合でき、RMSNorm は正規化のみに簡略化。
  • Action-Time Encoder 折り畳み: denoise の時間ステップは {1.0, 0.9, …, 0.1} の10通りのみ。time embedding + style projection を __init__ で全計算しテーブル化、SiLU 直前の bias まで融合。
  • QKV 融合: Q/K/V の3射影行列を1つの大行列に結合し、結果をスライスで取り出す。RoPE の cos/sin も重みに事前融合。

3. カーネル内部の最適化

24 個の GEMM に分解し、各々を専用 Triton カーネルで実装:

  • GEMM タイルチューニング (§4.1): cuBLAS のデフォルトタイルが最適でない場合、BLOCK_SIZE_M/N/K を手動探索(約1.5 ms 改善)。
  • 17層実行: Encoder 最終層は KV キャッシュのみ AE に渡すので attention/FFN を省略(0.7 ms)。
  • Gated Linear 融合 (§4.2): gate_projup_proj は同一入力への独立 matmul。入力タイル1ロードで2重みタイルを処理し GELU(gate)*up のみ書き戻す(1.7 ms)。
  • Partial Split-k (§4.3): 512×1152×1152 は 64×64 タイルで144ブロックとなり 128 SM に不均等。512×1152×1024(均等)+ 512×1152×128(split-2)に分割し1カーネル化。
  • スカラー演算統合 (§4.4): bias・残差・活性化を GEMM のエピローグに融合。RMSNorm はトークン統計を別バッファに先計算し、次の GEMM 累積後に正規化(約4 ms)。

4. システム全体の作法

  • 画像リサイズ: カメラ ISP 出力を 224×224 に近い 240×320 等にし、手書きカーネルで 60 µs 以下に。
  • ピンドメモリ / ゼロコピーで CPU↔GPU 転送を最小化。全中間テンソルを __init__ で静的確保し、forward 中に動的割り当てを起こさない(CUDA Graph 互換のため必須)。

理論下限との差は roofline-model で約30%と見積もられた。次の発展形が full-streaming-inference

関連: triton-language / realtime-vla-v2 / flow-matching