GAN

GAN(Generative Adversarial Network)は、ノイズからデータを生成するGeneratorと、入力が本物のデータか生成されたデータかを判別するDiscriminatorの2つのネットワークからなるモデル

Model

Generator

class Generator(nn.Module):
    """Fully connected generator network.

    Maps an input of width ``in_dim`` through three hidden layers of
    256, 512 and 1024 units to an output of width ``out_dim``.
    """

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        # Layers are registered as fc1..fc4, in that order, so state_dict
        # keys and seeded parameter initialisation stay stable.
        self.fc1 = nn.Linear(in_dim, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, 1024)
        self.fc4 = nn.Linear(1024, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the generated output for ``x``.

        Each hidden layer is followed by a leaky ReLU with negative slope
        0.2; the output layer is followed by tanh, so every returned value
        lies in [-1, 1].
        """
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.leaky_relu(layer(hidden), 0.2)
        return F.tanh(self.fc4(hidden))

Discriminator

class Discriminator(nn.Module):
    """Fully connected discriminator network.

    Maps an input of width ``in_dim`` through three hidden layers of
    1024, 512 and 256 units to a single sigmoid output per sample.
    """

    def __init__(self, in_dim: int):
        super(Discriminator, self).__init__()
        sizes = [in_dim, 1024, 512, 256, 1]
        self.fc1 = nn.Linear(sizes[0], sizes[1])
        self.fc2 = nn.Linear(sizes[1], sizes[2])
        self.fc3 = nn.Linear(sizes[2], sizes[3])
        self.fc4 = nn.Linear(sizes[3], sizes[4])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one value in [0, 1] per input row.

        Each hidden layer is followed by a leaky ReLU and dropout; the
        output layer is followed by a sigmoid.
        """
        # leaky relu slope
        slope = 0.2
        # dropout rate
        dropout = 0.5
        # F.dropout defaults to training=True, which would keep dropping
        # units after .eval(); tie it to the module's mode instead.
        x = F.leaky_relu(self.fc1(x), slope)
        x = F.dropout(x, dropout, training=self.training)
        x = F.leaky_relu(self.fc2(x), slope)
        x = F.dropout(x, dropout, training=self.training)
        x = F.leaky_relu(self.fc3(x), slope)
        x = F.dropout(x, dropout, training=self.training)
        return torch.sigmoid(self.fc4(x))

Train

Discriminator

def train_discriminator(
    generator: Generator,
    discriminator: Discriminator,
    optim_discriminator,
    x: torch.Tensor,
    criterion,
    batch_size: int,
    step: int,
):
    """Run one discriminator update on a real batch and a generated batch.

    Reads ``in_dim``, ``out_dim``, ``device`` and ``writer`` from the
    enclosing module scope.

    Args:
        generator: Network used to produce the fake samples.
        discriminator: Network being updated.
        optim_discriminator: Optimizer whose ``step()`` is called after
            the backward pass.
        x: Real samples; reshaped to ``(-1, out_dim)`` before use.
        criterion: Callable ``criterion(prediction, target)`` returning the
            loss (presumably a binary cross-entropy loss; confirm against
            the caller).
        batch_size: Number of fake samples to generate.
        step: Global step under which the loss is logged.
    """
    discriminator.zero_grad()

    # Real samples are labelled 1. The label tensor is sized from the data
    # itself so a batch with fewer than ``batch_size`` rows still gets
    # labels that match the discriminator output.
    x_real = x.view(-1, out_dim).to(device)
    y_real = torch.ones(x_real.size(0), 1, device=device)
    loss_real = criterion(discriminator(x_real), y_real)

    # Fake samples are labelled 0. They are detached because only the
    # discriminator is updated here; without it the backward pass would
    # also run through the generator and leave gradients on its parameters.
    z = torch.randn(batch_size, in_dim, device=device)
    x_fake = generator(z).detach()
    y_fake = torch.zeros(batch_size, 1, device=device)
    loss_fake = criterion(discriminator(x_fake), y_fake)

    loss = loss_real + loss_fake
    loss.backward()
    optim_discriminator.step()
    writer.add_scalar('loss/discriminator', loss.item(), step)

Generator

def train_generator(
    generator: Generator,
    discriminator: Discriminator,
    optim_generator,
    criterion,
    batch_size: int,
    step: int,
):
    """Run one generator update against the current discriminator.

    Reads ``in_dim``, ``device`` and ``writer`` from the enclosing module
    scope.

    Args:
        generator: Network being updated.
        discriminator: Network that scores the generated samples.
        optim_generator: Optimizer whose ``step()`` is called after the
            backward pass.
        criterion: Callable ``criterion(prediction, target)`` returning the
            loss (presumably a binary cross-entropy loss; confirm against
            the caller).
        batch_size: Number of samples to generate.
        step: Global step under which the loss is logged.
    """
    generator.zero_grad()

    z = torch.randn(batch_size, in_dim, device=device)
    # The target is all ones: the generator is scored on how strongly the
    # discriminator labels its samples as real.
    y = torch.ones(batch_size, 1, device=device)

    d_output = discriminator(generator(z))
    loss_generator = criterion(d_output, y)

    loss_generator.backward()
    optim_generator.step()
    writer.add_scalar('loss/generator', loss_generator.item(), step)
    del loss_generator

参考文献