# GAN implementation in PyTorch
# Version V1
import torch
import torch.nn as nn
from torchvision import transforms, datasets
from torch import optim as optim
import matplotlib
matplotlib.use('AGG')#或者PDF, SVG或PS
import matplotlib.pyplot as plt
import time
# Select GPU when available; all tensors/models are moved here explicitly.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

batch_size = 100

# MNIST dataset: ToTensor gives [0, 1]; Normalize([0.5], [0.5]) maps to
# [-1, 1], matching the generator's Tanh output range.
dataset = datasets.MNIST(
    root='./data/',
    train=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]),
    download=True,
)

# Data loader
dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

# BUG FIX: the original f-string literal spanned two physical lines, which is
# a SyntaxError; joined with an explicit "\n".
print(f"Length of total dataset = {len(dataset)},\n"
      f"Length of dataloader with having batch_size of {batch_size} = {len(dataloader)}")

# BUG FIX: `.next()` was removed from DataLoader iterators (PyTorch >= 1.13 /
# Python 3); use the builtin next().
dataiter = iter(dataloader)
images, labels = next(dataiter)
print(torch.min(images), torch.max(images))
class GeneratorModel(nn.Module):
    """Fully-connected GAN generator.

    Maps a 100-dim noise vector to a flattened 28x28 image (784 values)
    in [-1, 1], matching the Normalize([0.5], [0.5]) preprocessing of the
    real data.
    """

    def __init__(self):
        super(GeneratorModel, self).__init__()
        input_dim = 100    # latent (noise) dimension
        output_dim = 784   # 28 * 28 flattened image
        # Hidden layers use LeakyReLU activations (both G and D).
        self.hidden_layer1 = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.LeakyReLU(0.2)
        )
        self.hidden_layer2 = nn.Sequential(
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2)
        )
        self.hidden_layer3 = nn.Sequential(
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2)
        )
        # The generator's output layer uses Tanh to produce values in [-1, 1].
        self.hidden_layer4 = nn.Sequential(
            nn.Linear(1024, output_dim),
            nn.Tanh()
        )

    def forward(self, x):
        """Return generated images of shape (batch, 784) for noise x of shape (batch, 100)."""
        output = self.hidden_layer1(x)
        output = self.hidden_layer2(output)
        output = self.hidden_layer3(output)
        output = self.hidden_layer4(output)
        # BUG FIX: the original returned `output.to(device)`, coupling the
        # module to a global variable. The output already lives on the same
        # device as the parameters/input, so the transfer was a no-op; the
        # caller is responsible for device placement.
        return output
class DiscriminatorModel(nn.Module):
    """Fully-connected GAN discriminator.

    Maps a flattened 28x28 image (784 values) to a single probability in
    (0, 1) that the image is real.
    """

    def __init__(self):
        super(DiscriminatorModel, self).__init__()
        input_dim = 784   # 28 * 28 flattened image
        output_dim = 1    # real/fake probability
        # Hidden layers: LeakyReLU activations plus Dropout for regularization.
        self.hidden_layer1 = nn.Sequential(
            nn.Linear(input_dim, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )
        self.hidden_layer2 = nn.Sequential(
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )
        self.hidden_layer3 = nn.Sequential(
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )
        # The discriminator's output layer uses Sigmoid (required by BCELoss).
        self.hidden_layer4 = nn.Sequential(
            nn.Linear(256, output_dim),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return real/fake probabilities of shape (batch, 1) for images x of shape (batch, 784)."""
        output = self.hidden_layer1(x)
        output = self.hidden_layer2(output)
        output = self.hidden_layer3(output)
        output = self.hidden_layer4(output)
        # BUG FIX: the original returned `output.to(device)`, coupling the
        # module to a global variable; the output already lives on the same
        # device as the parameters/input, so the transfer was a no-op.
        return output
# Instantiate the two networks and move them to the selected device.
discriminator = DiscriminatorModel()
generator = GeneratorModel()
discriminator.to(device)
generator.to(device)

# BUG FIX: the original string literal spanned two physical lines, which is
# a SyntaxError; use an explicit "\n" separator instead.
print(generator, "\n", discriminator)

# Binary cross-entropy loss (discriminator already ends in Sigmoid).
criterion = nn.BCELoss()
# Separate Adam optimizers for D and G.
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002)

num_epochs = 100
batch = 100

outputs = []
# Losses & scores recorded once per epoch for plotting.
losses_g = []
losses_d = []
real_scores = []
fake_scores = []
for epoch_idx in range(num_epochs):
    start_time = time.time()
    for batch_idx, data_input in enumerate(dataloader):
        # Use the real size of this mini-batch; a hard-coded `batch` would
        # crash on a ragged final batch.
        cur_batch = data_input[0].size(0)
        real = data_input[0].view(cur_batch, 784).to(device)  # cur_batch x 784
        # BUG FIX: the original did `batch_size = data_input[1]`, clobbering
        # the scalar batch size with the *label tensor*; removed.

        # ---------- Train the discriminator ----------
        noise = torch.randn(cur_batch, 100).to(device)
        fake = generator(noise)  # cur_batch x 784
        disc_real = discriminator(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        # BUG FIX: detach the fake images for the D pass so D's backward does
        # not propagate into (or retain the graph of) the generator — this
        # removes the original's need for `retain_graph=True`.
        disc_fake = discriminator(fake.detach()).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        # D loss is the mean of the real and fake terms.
        lossD = (lossD_real + lossD_fake) / 2
        real_score = torch.mean(disc_real).item()
        fake_score = torch.mean(disc_fake).item()
        d_optimizer.zero_grad()
        lossD.backward()
        d_optimizer.step()

        # ---------- Train the generator ----------
        # G wants D(fake) to be pushed towards 1 (i.e. fool the updated D).
        gen_fake = discriminator(fake).view(-1)
        lossG = criterion(gen_fake, torch.ones_like(gen_fake))
        g_optimizer.zero_grad()
        lossG.backward()
        g_optimizer.step()

        # Periodically visualize real vs. generated samples.
        if (batch_idx + 1) % 600 == 0 and (epoch_idx + 1) % 10 == 0:
            print("Training Steps Completed: ", batch_idx)
            with torch.no_grad():
                generated_data = fake.cpu().view(cur_batch, 28, 28)
                real_data = real.cpu().view(cur_batch, 28, 28)
            plt.figure(figsize=(10, 2))
            print("Real Images")
            for i, img in enumerate(real_data[:10]):
                plt.subplot(2, 10, i + 1)
                plt.imshow(img.numpy(), interpolation='nearest', cmap='gray')
            plt.title("on " + str(epoch_idx + 1) + "th epoch")
            plt.show()
            print("Generated Images")
            plt.figure(figsize=(10, 2))
            for j, img in enumerate(generated_data[:10]):
                plt.subplot(2, 10, j + 1)
                plt.imshow(img.numpy(), interpolation='nearest', cmap='gray')
            plt.show()
            outputs.append((epoch_idx, real, fake))

    # BUG FIX: store scalar values (.item()) instead of loss *tensors*, which
    # would keep every batch's autograd graph alive and leak memory.
    losses_g.append(lossG.item())
    losses_d.append(lossD.item())
    real_scores.append(real_score)
    fake_scores.append(fake_score)
    print('Epochs [{}/{}] & Batch [{}/{}]: loss_d: {:.4f}, loss_g: {:.4f}, real_score: {:.4f}, fake_score: {:.4f}, took time: {:.0f}s'.format(
        (epoch_idx + 1), num_epochs, batch_idx + 1, len(dataloader), lossD, lossG, real_score, fake_score, time.time() - start_time))

    if epoch_idx % 10 == 0:
        # Loss curves
        plt.plot(losses_d, '-')
        plt.plot(losses_g, '-')
        plt.xlabel('epoch')
        plt.ylabel('loss')
        plt.legend(['Discriminator', 'Generator'])
        plt.title('Losses')
        plt.savefig('Losses.jpg')
        plt.show()
        plt.close()
        # Score curves
        plt.plot(real_scores, '-')
        plt.plot(fake_scores, '-')
        plt.xlabel('epoch')
        plt.ylabel('score')
        plt.legend(['Real', 'Fake'])
        plt.title('Scores')
        plt.savefig('Scores.jpg')
        plt.show()
        plt.close()
# Persist the trained weights for later sampling or fine-tuning.
for model, path in ((generator, 'generator.pth'), (discriminator, 'discriminator.pth')):
    torch.save(model.state_dict(), path)