Recently I've been working on using GANs to generate skeleton-level interactions, and coincidentally I found an interesting dataset of Pokémon images on Kaggle. Wouldn't it be interesting to see what kind of new Pokémon a GAN can generate?

1. What are GANs

Generative Adversarial Networks (GANs) typically consist of two neural networks. One, called the generator $G$, generates fake data and tries to fool the other network, called the discriminator $D$, whose job is to distinguish real data from the fake data generated by $G$.

Illustration of GANs

For a mathematical formulation: $G$ takes a random variable $z$ sampled from a distribution $p(z)$ as input and generates fake data $G(z)$. $D$ takes either real data $x$ sampled from $p_{data}(x)$ or fake data $G(z)$ as its input, and outputs the probability that the input is real data. The goal of $G$ is to generate data that look as real as possible, i.e. it wants to maximize $D(G(z))$. Meanwhile, the goal of $D$ is to distinguish fake from real data, i.e. it wants to maximize $D(x)$ and minimize $D(G(z))$. Putting all this together, the GAN framework is a zero-sum game played between the two networks. Mathematically, the optimization goal is

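$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}\left[\log D(x)\right] + \mathbb{E}_{z \sim p(z)}\left[\log\left(1 - D(G(z))\right)\right]$$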

2. Implementing GANs with PyTorch

There are several kinds of GANs; I chose WGAN-GP because of its stability during training.

GAN model

Generator network:

For $G$ I use two fully connected layers followed by 4 transposed convolutional layers to build a $3\times128\times128$ RGB image.

class G_net(nn.Module):
    """Generator: maps a latent vector to an image with values in [-1, 1].

    A two-layer MLP (ending in BatchNorm1d) expands the latent vector to a
    ``units[3] x fs[0] x fs[1]`` feature map. Four transposed-convolution
    blocks then upsample it. The last block ends in ``Tanh``, so every output
    value lies in [-1, 1].

    Layer hyperparameters (``units``, ``fs``, ``k_size``, ``strides``,
    ``padding``, ``small_image_size``) are read from module-level globals.

    NOTE(review): the original per-layer shape comments ([64,12,12],
    [32,27,27], [3,57,57], [3,289,289]) do not match the 3x128x128 target
    described in the text. The real output size depends entirely on the
    globals above, so verify it.

    Args:
        in_dim: size of the latent input vector (defaults to ``latent_dim``).
    """
    def __init__(self, in_dim=latent_dim):
        super(G_net, self).__init__()
        self.fc1 = nn.Sequential(
            nn.Linear(in_dim, units[4]),
            nn.ReLU(),
            nn.Linear(units[4], units[3] * fs[0] * fs[1]),
            nn.ReLU(),
            nn.BatchNorm1d(units[3] * fs[0] * fs[1]),
        )
        self.ct1 = nn.Sequential(
            self._up(units[3], units[2], 3),
            nn.BatchNorm2d(units[2]),
            nn.ReLU(),
        )
        self.ct2 = nn.Sequential(
            self._up(units[2], units[1], 2),
            nn.BatchNorm2d(units[1]),
            nn.ReLU(),
        )
        self.ct3 = nn.Sequential(
            self._up(units[1], units[0], 1),
            nn.BatchNorm2d(units[0]),
            nn.ReLU(),
        )
        # Output block: no normalization; Tanh squashes values to [-1, 1].
        self.ct4 = nn.Sequential(
            self._up(units[0], small_image_size[0], 0),
            nn.Tanh(),
        )

    @staticmethod
    def _up(in_ch, out_ch, idx):
        """Build a ConvTranspose2d from the idx-th kernel/stride/padding hyperparameters."""
        return nn.ConvTranspose2d(
            in_ch, out_ch, k_size[idx],
            stride=strides[idx],
            padding=padding[idx],
            # Must be an int. `strides[idx] / 2` is a float under Python 3,
            # and PyTorch rejects it (output_padding must be ints); floor
            # division gives the intended value.
            output_padding=strides[idx] // 2,
        )

    def forward(self, X):
        """Map a batch of latent vectors ``X`` to a batch of generated images."""
        X = self.fc1(X)
        # Reshape the flat MLP output into a (units[3], fs[0], fs[1]) feature map.
        X = self.ct1(X.view(-1, units[3], fs[0], fs[1]))
        X = self.ct2(X)
        X = self.ct3(X)
        return self.ct4(X)

Discriminator network:

For $D$ I use 4 convolutional layers followed by 2 fully connected layers. According to the WGAN-GP paper, the last layer should not use a sigmoid activation: the critic outputs an unbounded score rather than a probability.

class D_net(nn.Module):
    """WGAN-GP critic: 4 conv blocks followed by 2 fully connected layers.

    The output is a raw, unbounded score (no sigmoid), as a Wasserstein
    critic requires.

    Normalization: the WGAN-GP paper advises against batch normalization in
    the critic. The gradient penalty is a per-sample quantity, and batch norm
    makes each sample's score depend on the rest of the batch. Each conv block
    therefore uses ``nn.GroupNorm(1, C)``, i.e. layer normalization over
    (C, H, W) for each sample, which the paper suggests as a drop-in
    replacement.

    Layer hyperparameters (``image_size``, ``units``, ``k_size``, ``strides``,
    ``padding``, ``fs``) are read from module-level globals.
    """
    def __init__(self):
        super(D_net, self).__init__()
        self.conv1 = self._down(image_size[0], units[0], 0)
        self.conv2 = self._down(units[0], units[1], 1)
        self.conv3 = self._down(units[1], units[2], 2)
        self.conv4 = self._down(units[2], units[3], 3)
        self.fc1 = nn.Linear(units[3] * fs[0] * fs[1], units[4])
        self.dp = nn.Dropout(0.5)
        self.d_out = nn.Linear(units[4], 1)

    @staticmethod
    def _down(in_ch, out_ch, idx):
        """Conv -> per-sample layer norm -> LeakyReLU(0.2), using layer idx's hyperparameters."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, k_size[idx], strides[idx], padding=padding[idx]),
            nn.GroupNorm(1, out_ch),
            nn.LeakyReLU(0.2),
        )

    def forward(self, X):
        """Return one unbounded critic score per input, shape ``(N, 1)``."""
        X = self.conv1(X)
        X = self.conv2(X)
        X = self.conv3(X)
        X = self.conv4(X)
        # Flatten; assumes the conv stack yields a units[3] x fs[0] x fs[1] map.
        X = X.view((-1, units[3] * fs[0] * fs[1]))
        X = self.dp(F.leaky_relu(self.fc1(X)))
        return self.d_out(X)

Gradient penalty

LAMBDA = 10  # Gradient penalty lambda hyperparameter


def loss_with_penalty(D, rdata, fdata):
    """Return the WGAN-GP gradient penalty ``LAMBDA * mean((||grad D(x_hat)||_2 - 1)^2)``.

    ``x_hat`` is sampled uniformly on the straight line between each real
    sample and its paired fake sample. One mixing coefficient is drawn per
    sample and broadcast over all of that sample's remaining dimensions.

    Requires PyTorch >= 0.4 (``Tensor.device``, ``requires_grad_``).

    Args:
        D: critic callable, applied to the interpolated batch.
        rdata: batch of real data, batch dimension first. Callers pass
            ``.data`` (no autograd history).
        fdata: batch of generated data with the same shape as ``rdata``.

    Returns:
        Scalar tensor. The input gradients are built with ``create_graph=True``,
        so the penalty can be backpropagated into ``D``'s parameters.
    """
    # One coefficient per sample, shaped (B, 1, ..., 1) to broadcast. A
    # per-element alpha would not interpolate along the real-fake line.
    alpha_shape = (rdata.size(0),) + (1,) * (rdata.dim() - 1)
    alpha = torch.rand(alpha_shape, device=rdata.device, dtype=rdata.dtype)

    interpolates = alpha * rdata + (1 - alpha) * fdata
    # Fresh leaf tensor so we can differentiate the critic output w.r.t. it.
    interpolates = interpolates.detach().requires_grad_(True)

    disc_interpolates = D(interpolates)
    gradients = torch.autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]
    # Per-sample L2 norm over all non-batch dimensions.
    gradients = gradients.reshape(gradients.size(0), -1)

    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
    return gradient_penalty

Training procedure

Data Preprocessing

def _load_rgb(path, size):
    """Open the image at ``path``, force 3-channel RGB, and resize it to ``size``.

    The file handle is closed before returning; the returned image is an
    in-memory copy produced by ``convert``/``resize``.
    """
    with Image.open(path) as im:
        # The dataset's Normalize uses 3-value mean/std, so every image must
        # have exactly 3 channels. convert('RGB') guards against RGBA, palette
        # or grayscale files, and yields an identical copy for RGB inputs.
        # NOTE(review): PIL's resize takes (width, height); image_size[1:] is
        # presumably (H, W), which only matters for non-square sizes. Confirm.
        return im.convert('RGB').resize(tuple(size))


# Sorted so the file order (and any index-based inspection) is reproducible.
imgs_dir = sorted(glob.glob(data_dir + '*'))
imgs = [_load_rgb(fil, image_size[1:]) for fil in imgs_dir]
class PK_DATASET(Dataset):
    """Map-style dataset over an in-memory list of images.

    Each item is passed through ``ToTensor`` and then a per-channel
    ``Normalize`` with mean 0.5 and std 0.5 for three channels.

    Args:
        imgs: indexable collection of images, stored as ``self.data``.
    """

    def __init__(self, imgs):
        self.data = imgs
        to_tensor = ToTensor()
        normalise = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        self.trans = Compose([to_tensor, normalise])

    def __len__(self):
        """Number of stored images."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the transformed image at position ``idx``."""
        image = self.data[idx]
        return self.trans(image)
# Wrap the preloaded images and serve shuffled mini-batches of size batch_size.
Pk_dataset = PK_DATASET(imgs)
Pk_dataloader = DataLoader(
    Pk_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
)

Initialize NNs

latent_dim = 128  # dimension of the latent variable z


def weights_init(m):
    """Re-initialise the parameters of module ``m`` in place; meant for ``Module.apply``.

    WGAN needs carefully initialised weights:

    * class name contains ``Conv``: weight ~ N(0, 0.02); bias left as is.
    * class name contains ``BatchNorm``: weight ~ N(1, 0.02), bias = 0.
    * anything else (e.g. ``nn.Linear``) is left untouched.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        m.weight.data.normal_(0.0, 0.02)
        return
    if 'BatchNorm' in layer_kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
# Build both networks first, then re-initialise them (D before G), keeping
# the original order of random-number consumption.
d_model, g_model = D_net(), G_net(latent_dim)
d_model.apply(weights_init)
g_model.apply(weights_init)

Training

d_iter = 1  # critic (D) updates per batch
g_iter = 1  # generator (G) updates per batch
epoch = 3000
for e in range(epoch):
    for data in Pk_dataloader:

        for _ in range(d_iter):
            # ---- critic step ----
            # Re-enable D's gradients (they are switched off during the G step).
            for p in d_model.parameters():
                p.requires_grad = True
            d_model.zero_grad()
            # Real data: loss term -mean(D(x)).
            true_data = Variable(data)
            if use_GPU:
                true_data = true_data.cuda(device)
            d_true_score = d_model(true_data)
            true_loss = -d_true_score.mean()
            # Fake data: loss term +mean(D(G(z))).
            noise = get_noise(true_data.size()[0])
            if use_GPU:
                noise = noise.cuda(device)
            # detach(): this loss only updates D, so don't backprop through G.
            # Previously G's gradients were computed here and then discarded
            # by g_model.zero_grad() in the G step.
            fake_data = g_model(noise).detach()
            d_fake_score = d_model(fake_data)
            fake_loss = d_fake_score.mean()
            w_loss = loss_with_penalty(d_model, true_data.data, fake_data.data)
            loss = true_loss + fake_loss + w_loss
            loss.backward()
            d_optimizer.step()
        for _ in range(g_iter):
            # ---- generator step ----
            g_model.zero_grad()
            # Freeze D so this backward pass skips gradients for D's parameters.
            for p in d_model.parameters():
                p.requires_grad = False
            # NOTE(review): batch size comes from get_noise's default here; confirm.
            noise = get_noise()
            if use_GPU:
                noise = noise.cuda(device)
            fake_data = g_model(noise)
            g_score = d_model(fake_data)
            g_loss = -g_score.mean()
            g_loss.backward()
            g_optimizer.step()
    if e % 10 == 0:
        # Every 10 epochs: map the last G batch from [-1, 1] to [0, 1], save
        # sample index 3 (needs a batch of at least 4), and checkpoint D and G.
        fake_imgs = fake_data.cpu().data * 0.5 + 0.5
        img = ToPILImage()(fake_imgs[3])
        img.save('result/%s/G_result/iter_%d.png' % (code, e))
        torch.save(d_model, 'result/%s/D_checkpoint/iter_d_%d.pt' % (code, e))
        torch.save(g_model, 'result/%s/G_checkpoint/iter_g_%d.pt' % (code, e))

3. Results

After 3000 epochs of training, here are some new Pokémon generated by $G$:

As a demonstration of how $G$ improved:

The full source code can be found in this GitHub repository.