PyTorch

深層学習フレームワーク。本ヴォルトの ML 実装(generative-adversarial-networkbert チャットボット、CNN で MNIST など)の共通基盤。

よく使う要素

  • Tensor: torch.from_numpy()x.detach().numpy()detach() は勾配情報を切る。形状変換は view/reshape
  • 推論モード: model.eval()(Dropout/BatchNorm の振る舞いが変わる)+ torch.no_grad()、または torch.inference_mode()(推論専用 Tensor が用意される)。torch.set_grad_enabled(bool) で train/val をまとめて書ける。
  • メモリ: torch.cuda.empty_cache() で GPU メモリ解放。
  • TensorBoard: from torch.utils.tensorboard import SummaryWriter で損失や生成画像グリッド(add_images)を記録。TensorboardX / pytorch-ignite も併用される。
  • Dataset/DataLoader: datasets(MNIST, CelebA 等)+ transforms(ToTensor, Normalize)+ DataLoader(batch_size, shuffle, drop_last)。

エコシステムとの関係

PyTorch 計算グラフは ml-compilers(TVM / Glow など)が最適化対象とする IR の入口でもある。VLA 実装(LeRobot の Pi0Policy など)も torch.nn.Module 上に構築され、推論高速化は realtime-vla で Triton カーネルに置き換えられる。

関連