PyTorch

list from #pytorch 

Tensor

torch.from_numpy() <-> x.detach().numpy()

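  • torch.from_numpy(): shares memory with the NumPy array (no copy is made)
  • detach(): returns a tensor cut off from the autograd graph, so gradient information is dropped; it is required before .numpy() when the tensor requires grad, and a GPU tensor additionally needs .cpu()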
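A small sketch of the round trip, showing that `torch.from_numpy()` does not copy and that a tensor which requires grad must be detached before `.numpy()`. The array values are arbitrary.

```python
import numpy as np
import torch

arr = np.arange(3, dtype=np.float32)
t = torch.from_numpy(arr)      # shares memory with arr (no copy)
arr[0] = 10.0                  # the change is visible through t as well
print(t)

w = torch.ones(3, requires_grad=True)
y = w * 2                      # y is part of the autograd graph
# y.numpy() would raise a RuntimeError because y requires grad.
# detach() first; .cpu() is a no-op on CPU but needed for CUDA tensors.
out = y.detach().cpu().numpy()
print(out)
```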

Context Manager

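  • model.eval(): changes the behavior of Dropout and BatchNorm (Dropout is disabled and BatchNorm uses its running statistics instead of batch statistics)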
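  • torch.set_grad_enabled(bool): lets you write the train and val loops as one
    • True: use together with model.train()
    • False: use together with model.eval()
    • it only switches gradient tracking on or off; the model's train/eval mode still has to be set separately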
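One way to share a single loop between training and validation is sketched below. `run_epoch` is a hypothetical helper name, and the toy model, random data, and SGD settings are placeholders; the point is that `set_grad_enabled` and `model.train(...)` are toggled by the same flag.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

def run_epoch(model, loader, loss_fn, optimizer=None):
    """Run one epoch: training if an optimizer is given, else validation."""
    is_train = optimizer is not None
    model.train(is_train)                   # train(False) is the same as eval()
    total = 0.0
    with torch.set_grad_enabled(is_train):  # only toggles autograd recording
        for x, y in loader:
            loss = loss_fn(model(x), y)
            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total += loss.item() * x.size(0)
    return total / len(loader.dataset)

# Usage with placeholder data
torch.manual_seed(0)
ds = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
loader = DataLoader(ds, batch_size=16)
model = nn.Sequential(nn.Linear(4, 8), nn.Dropout(0.1), nn.Linear(8, 1))
opt = torch.optim.SGD(model.parameters(), lr=0.01)
train_loss = run_epoch(model, loader, nn.MSELoss(), opt)
val_loss = run_epoch(model, loader, nn.MSELoss())
print(train_loss, val_loss)
```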

At inference time

model.eval()
# either: just turn off gradient tracking
with torch.no_grad():
    model(data)
# or: the stricter, faster inference mode
with torch.inference_mode():
    model(data)
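  • torch.inference_mode(): tensors created inside are inference-only tensors, and the previous mode is restored when the with block exits (RAII-style). It skips more autograd bookkeeping than no_grad, so it can be faster, but the resulting tensors cannot later be used in autograd-recorded operations.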
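The difference between the two context managers can be checked directly with `Tensor.is_inference()`. The tiny linear model and random input below are placeholders.

```python
import torch
from torch import nn

model = nn.Linear(4, 2)
model.eval()
x = torch.randn(3, 4)

with torch.no_grad():
    out_ng = model(x)          # ordinary tensor, just not tracked by autograd
with torch.inference_mode():
    out_im = model(x)          # an "inference tensor"

print(out_ng.requires_grad, out_ng.is_inference())
print(out_im.requires_grad, out_im.is_inference())
print(torch.is_inference_mode_enabled())  # the mode is restored after the with block
# To feed out_im into autograd later, make a normal copy with out_im.clone()
# outside inference mode.
```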

Saving memory

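  • torch.cuda.empty_cache(): returns unused memory held by PyTorch's CUDA caching allocator to the GPU; it does not free memory still occupied by live tensors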
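To see what `empty_cache()` actually releases, one can compare `torch.cuda.memory_allocated()` (memory held by live tensors) with `torch.cuda.memory_reserved()` (memory kept by the caching allocator). A sketch assuming a CUDA device is present; on a CPU-only machine the hypothetical helper `cuda_memory_mb` returns None and the demo is skipped. The 256 MiB tensor size is arbitrary.

```python
import gc
import torch

def cuda_memory_mb():
    """Return (allocated, reserved) in MiB for the current CUDA device, or None without a GPU."""
    if not torch.cuda.is_available():
        return None
    mib = 1024 ** 2
    return torch.cuda.memory_allocated() / mib, torch.cuda.memory_reserved() / mib

if torch.cuda.is_available():
    x = torch.empty(64, 1024, 1024, device="cuda")  # 256 MiB of float32
    print(cuda_memory_mb())   # both numbers include x
    del x                     # drop the last reference; the block goes back to the cache
    gc.collect()
    print(cuda_memory_mb())   # allocated drops, reserved stays (still cached)
    torch.cuda.empty_cache()  # hand unused cached blocks back to the driver
    print(cuda_memory_mb())   # reserved should drop as well
```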

TensorBoard

$ pip install tensorboard
from torch.utils.tensorboard import SummaryWriter
import datetime
 
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
writer = SummaryWriter(f'runs/mnist_gan_exp_{timestamp}')
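Besides images, scalars such as losses are typically logged with `add_scalar`. A minimal sketch that writes a dummy decaying curve into a temporary directory (an assumption for the example; the note itself uses `runs/...`); the run can then be viewed with `tensorboard --logdir` pointed at that directory.

```python
import math
import os
import tempfile
from torch.utils.tensorboard import SummaryWriter

log_dir = tempfile.mkdtemp(prefix="runs_")  # placeholder location
writer = SummaryWriter(log_dir)
for step in range(100):
    # dummy loss curve, only to have something to plot
    writer.add_scalar("loss/train", math.exp(-step / 20), step)
writer.flush()
writer.close()
print(os.listdir(log_dir))  # one events.out.tfevents.* file
```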

Sending an image grid

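z = torch.randn(batch_size, in_dim, device=device)
with torch.no_grad():  # no gradients needed just for logging
    images = generator(z).view(-1, *shape)  # shape = (C, H, W)
# add_images takes an NCHW batch and shows it as one grid; float values are expected in [0, 1]
writer.add_images('generated_image', images, epoch)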

Dataset/Dataloader

from torch.utils.data import random_split
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
 
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5)
])
data = datasets.CelebA(root='./dataset', download=True, transform=transform)
size = 10000
data, _ = random_split(data, [size, len(data) - size])
 
batch_size: int = 256
dataloader = DataLoader(data, batch_size=batch_size, shuffle=True, drop_last=True)
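The effect of `random_split` and `drop_last=True` can be checked without downloading CelebA. In this sketch a random `TensorDataset` stands in for the real data, and the sizes and seed are arbitrary; passing a seeded `generator` makes the split reproducible.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Synthetic stand-in for the image dataset: 1000 samples of 3x8x8 "images"
full = TensorDataset(torch.randn(1000, 3, 8, 8), torch.zeros(1000))
size = 600
g = torch.Generator().manual_seed(42)  # reproducible split
subset, _ = random_split(full, [size, len(full) - size], generator=g)

loader = DataLoader(subset, batch_size=256, shuffle=True, drop_last=True)
print(len(loader))  # 600 // 256 full batches; the remaining 88 samples are dropped
for x, _ in loader:
    print(x.shape)  # every batch has exactly 256 samples
```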

Transforms

References