GAN
GeneratorとDiscriminatorからなるモデル
Model
Generator
class Generator(nn.Module):
    """Fully connected generator network.

    Maps an input whose last dimension is ``in_dim`` through hidden layers of
    width 256, 512 and 1024 to an output of width ``out_dim``. Hidden layers
    use LeakyReLU with negative slope 0.2; the output passes through tanh, so
    every output value lies in [-1, 1].
    """

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        # Created in this order (fc1..fc4) so parameter initialisation order
        # and the state_dict keys stay stable.
        self.fc1 = nn.Linear(in_dim, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, 1024)
        self.fc4 = nn.Linear(1024, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.leaky_relu(layer(hidden), negative_slope=0.2)
        return torch.tanh(self.fc4(hidden))


# Discriminator
class Discriminator(nn.Module):
    """Fully connected discriminator network.

    Maps an input whose last dimension is ``in_dim`` through hidden layers of
    width 1024, 512 and 256 to a single output unit. Each hidden layer is
    followed by LeakyReLU and dropout; the output passes through a sigmoid, so
    ``forward`` returns values in [0, 1] with a trailing dimension of 1.

    Args:
        in_dim: Size of the last dimension of the input.
        negative_slope: LeakyReLU negative slope (keyword-only, default 0.2).
        dropout: Dropout probability after each hidden layer
            (keyword-only, default 0.5). Applied only in training mode.
    """

    def __init__(
        self,
        in_dim: int,
        *,
        negative_slope: float = 0.2,
        dropout: float = 0.5,
    ):
        super().__init__()
        sizes = [in_dim, 1024, 512, 256, 1]
        self.fc1 = nn.Linear(sizes[0], sizes[1])
        self.fc2 = nn.Linear(sizes[1], sizes[2])
        self.fc3 = nn.Linear(sizes[2], sizes[3])
        self.fc4 = nn.Linear(sizes[3], sizes[4])
        self.negative_slope = negative_slope
        self.dropout_p = dropout

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for hidden in (self.fc1, self.fc2, self.fc3):
            x = F.leaky_relu(hidden(x), self.negative_slope)
            # F.dropout defaults to training=True regardless of module mode;
            # pass self.training so that .eval() actually disables dropout.
            x = F.dropout(x, self.dropout_p, training=self.training)
        # NOTE(review): the sigmoid output presumably pairs with nn.BCELoss at
        # the call site; returning logits with BCEWithLogitsLoss would be more
        # numerically stable, but would change this module's output contract.
        return torch.sigmoid(self.fc4(x))


# Train
Discriminator
def train_discriminator(
    generator: Generator,
    discriminator: Discriminator,
    optim_discriminator,
    x: torch.Tensor,
    criterion,
    batch_size: int,
    step: int,
):
    """Run one discriminator update on a real batch and a generated batch.

    The discriminator is scored against target 1 on the real samples ``x`` and
    target 0 on samples produced by ``generator`` from standard-normal noise.
    The two ``criterion`` losses are summed and back-propagated,
    ``optim_discriminator`` takes one step, and the summed loss is logged to
    ``writer`` under ``'loss/discriminator'`` at ``step``.

    Uses the free (global) names ``out_dim``, ``in_dim``, ``device`` and
    ``writer``, which are defined outside this function.

    Args:
        generator: Network producing fake samples; it is not updated here.
        discriminator: Network being trained.
        optim_discriminator: Optimizer stepped for the discriminator.
        x: Batch of real samples; reshaped to ``(-1, out_dim)``.
        criterion: Loss taking ``(prediction, target)``.
        batch_size: Number of fake samples to generate. Real labels are sized
            from ``x`` itself, so a shorter final batch is handled.
        step: Global step used for logging.

    Returns:
        The summed discriminator loss as a Python float.
    """
    discriminator.zero_grad()

    # Real samples, target 1. Size the labels from the data, not batch_size,
    # so a short last batch does not cause a loss shape mismatch.
    x_real = x.reshape(-1, out_dim).to(device)
    y_real = torch.ones(x_real.size(0), 1, device=device)
    loss_real = criterion(discriminator(x_real), y_real)

    # Fake samples, target 0. detach() stops this backward pass at the
    # generator output: only the discriminator is updated here, so building
    # gradients for the generator's parameters would be wasted work.
    z = torch.randn(batch_size, in_dim, device=device)
    x_fake = generator(z).detach()
    y_fake = torch.zeros(batch_size, 1, device=device)
    loss_fake = criterion(discriminator(x_fake), y_fake)

    loss = loss_real + loss_fake
    loss.backward()
    optim_discriminator.step()

    loss_value = loss.item()
    writer.add_scalar('loss/discriminator', loss_value, step)
    return loss_value


# Generator
def train_generator(
    generator: Generator,
    discriminator: Discriminator,
    optim_generator,
    criterion,
    batch_size: int,
    step: int,
):
    """Run one generator update using the discriminator's judgement.

    Generates ``batch_size`` samples from standard-normal noise and scores
    them with ``discriminator`` against a target of 1. Minimising
    ``criterion`` therefore pushes the generator toward samples the
    discriminator classifies as real. ``optim_generator`` takes one step and
    the loss is logged to ``writer`` under ``'loss/generator'`` at ``step``.

    Uses the free (global) names ``in_dim``, ``device`` and ``writer``, which
    are defined outside this function.

    Args:
        generator: Network being trained.
        discriminator: Network used to score generated samples.
        optim_generator: Optimizer stepped for the generator.
        criterion: Loss taking ``(prediction, target)``.
        batch_size: Number of samples to generate.
        step: Global step used for logging.

    Returns:
        The generator loss as a Python float.
    """
    generator.zero_grad()

    z = torch.randn(batch_size, in_dim, device=device)
    # Target "real" (1) for generated samples.
    y = torch.ones(batch_size, 1, device=device)

    d_output = discriminator(generator(z))
    loss_generator = criterion(d_output, y)
    # Backward also writes .grad on the discriminator's parameters. Only
    # optim_generator steps here; train_discriminator calls
    # discriminator.zero_grad() before computing its own gradients.
    loss_generator.backward()
    optim_generator.step()

    loss = loss_generator.item()
    writer.add_scalar('loss/generator', loss, step)
    return loss