相关包
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import torchvision
from torchvision import datasets, transforms
%matplotlib inline
训练过程封装
def _safe_div(num, den):
    """Return ``num / den``, or NaN when ``den`` is 0 (e.g. an empty loader)."""
    return num / den if den else float('nan')


def _train_one_epoch(model, loader, loss_fn, opt):
    """Run one optimisation pass over ``loader``.

    Returns:
        ``(mean per-sample loss, accuracy)`` over every sample seen, or NaNs
        if ``loader`` yields no samples.
    """
    model.train()  # training mode: matters for layers such as Dropout/BatchNorm
    correct = 0         # correctly classified samples in this epoch
    total = 0           # samples seen in this epoch
    running_loss = 0.0  # accumulated loss, weighted by batch size
    for x, y in loader:
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        with torch.no_grad():
            batch_size = y.size(0)
            correct += (torch.argmax(y_pred, dim=1) == y).sum().item()
            total += batch_size
            # ``loss`` is a batch mean (assuming the loss uses the default
            # reduction='mean'); re-weight by batch size so the epoch value is
            # a true per-sample mean even when the last batch is short.
            running_loss += loss.item() * batch_size
    return _safe_div(running_loss, total), _safe_div(correct, total)


def _evaluate(model, loader, loss_fn):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Returns:
        ``(mean per-sample loss, accuracy)`` over every sample seen, or NaNs
        if ``loader`` yields no samples.
    """
    model.eval()  # inference mode: matters for layers such as Dropout/BatchNorm
    correct = 0
    total = 0
    running_loss = 0.0
    with torch.no_grad():
        for x, y in loader:
            y_pred = model(x)
            batch_size = y.size(0)
            # Same batch-size weighting as in training (see _train_one_epoch).
            running_loss += loss_fn(y_pred, y).item() * batch_size
            correct += (torch.argmax(y_pred, dim=1) == y).sum().item()
            total += batch_size
    return _safe_div(running_loss, total), _safe_div(correct, total)


def fit(epoch, model, trainloader, testloader, loss_fn=None, opt=None):
    """Train ``model`` for one epoch, evaluate it, print and return the metrics.

    Args:
        epoch: Epoch index; only used in the printed report.
        model: Classifier whose outputs are argmax-ed over dim 1 to obtain
            predicted labels.
        trainloader: Iterable of ``(x, y)`` training batches.
        testloader: Iterable of ``(x, y)`` evaluation batches.
        loss_fn: Callable ``loss_fn(y_pred, y)`` returning a scalar tensor.
            Defaults to the module-level ``loss_func`` (original behaviour).
        opt: Optimizer over ``model``'s parameters. Defaults to the
            module-level ``optimizer`` (original behaviour).

    Returns:
        ``(epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc)``. The
        losses are per-sample means (assuming ``loss_fn`` averages over the
        batch) and the accuracies are fractions in [0, 1]. A value is NaN when
        its loader yields no samples. The model is left in eval mode.
    """
    if loss_fn is None:
        loss_fn = loss_func  # notebook-level global, as before
    if opt is None:
        opt = optimizer      # notebook-level global, as before

    epoch_loss, epoch_acc = _train_one_epoch(model, trainloader, loss_fn, opt)
    epoch_test_loss, epoch_test_acc = _evaluate(model, testloader, loss_fn)

    print('epoch: ', epoch,
          'loss: ', round(epoch_loss, 3),
          'accuracy: ', round(epoch_acc, 3),
          'test_loss: ', round(epoch_test_loss, 3),
          'test_accuracy: ', round(epoch_test_acc, 3))
    return epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc
这里的correct指的是每一轮（epoch）中分类正确的样本数；
total指的是每一轮中参与计算的样本总数；
running_loss指的是每一轮中累计的损失值总和。
初始化
# Build a fresh network and an Adam optimiser over all of its parameters.
model = Model()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

# Number of passes over the training data.
epochs = 100
模型训练
# Per-epoch metric histories, one entry appended per call to fit().
train_loss, train_acc, test_loss, test_acc = [], [], [], []

for epoch in range(epochs):
    (epoch_loss, epoch_acc,
     epoch_test_loss, epoch_test_acc) = fit(epoch, model, train_dl, test_dl)
    train_loss += [epoch_loss]
    train_acc += [epoch_acc]
    test_loss += [epoch_test_loss]
    test_acc += [epoch_test_acc]
训练可视化
# x-axis taken from the recorded history (1-based epochs) rather than
# `epochs`, so the plots still work if training was stopped early.
epoch_axis = range(1, len(train_loss) + 1)

# Loss curves in their own figure.
plt.figure()
plt.plot(epoch_axis, train_loss, label='train_loss')
plt.plot(epoch_axis, test_loss, label='test_loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()

# Accuracy curves in a separate figure; without a new figure they would be
# drawn onto the same axes as the loss curves.
plt.figure()
plt.plot(epoch_axis, train_acc, label='train_acc')
plt.plot(epoch_axis, test_acc, label='test_acc')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.legend()
plt.show()