PyTorch
深層学習フレームワーク。本ヴォルトの ML 実装(generative-adversarial-network、bert チャットボット、CNN で MNIST など)の共通基盤。
よく使う要素
- Tensor:
torch.from_numpy()↔x.detach().numpy()。detach()は勾配情報を切る。形状変換はview/reshape。 - 推論モード:
model.eval()(Dropout/BatchNorm の振る舞いが変わる)+torch.no_grad()、またはtorch.inference_mode()(推論専用 Tensor が用意される)。torch.set_grad_enabled(bool)で train/val をまとめて書ける。 - メモリ:
torch.cuda.empty_cache()で GPU メモリ解放。 - TensorBoard:
from torch.utils.tensorboard import SummaryWriterで損失や生成画像グリッド(add_images)を記録。TensorboardX / pytorch-ignite も併用される。 - Dataset/DataLoader:
datasets(MNIST, CelebA 等)+transforms(ToTensor, Normalize)+DataLoader(batch_size, shuffle, drop_last)。
エコシステムとの関係
PyTorch 計算グラフは ml-compilers(TVM / Glow など)が最適化対象とする IR の入口でもある。VLA 実装(LeRobot の Pi0Policy など)も torch.nn.Module 上に構築され、推論高速化は realtime-vla で Triton カーネルに置き換えられる。
関連
- ml-compilers / realtime-vla
- _moc-ml-robotics(ml-robotics クラスタの atomic ノート群)