torchfurnace
torchfurnace 是一个集快速模型训练、日志管理、模型 checkpoint 管理、TensorBoard 可视化、I/O 加速和模型大小统计于一身的工具包。
使用这个工具包可以快速构建深度学习训练流程：无需自己编写各种训练逻辑，已经定义好的模型也不需要修改，
可以说是拿来即用。
安装：pip install torchfurnace
github: https://github.com/tianyu-su/torchfurnace
下面的例子演示如何快速搭建训练流程：使用 VGG16 训练 CIFAR10。
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.optim.lr_scheduler import MultiStepLR
from torchfurnace import Engine, Parser
from torchfurnace.utils.function import accuracy
# define training process of your model
class VGGNetEngine(Engine):
    """torchfurnace ``Engine`` that trains/evaluates a classifier with cross-entropy loss.

    Only the two hooks below are customised; the rest of the training loop is
    inherited from ``Engine``.
    """

    @staticmethod
    def _on_forward(training, model, inp, target, optimizer=None) -> dict:
        """Run one batch through ``model`` and, when ``training``, take an optimizer step.

        Args:
            training: truthy for a training batch, falsy for an evaluation batch.
            model: the network to run on ``inp``.
            inp: input batch passed to ``model``.
            target: class-index targets, as consumed by ``F.cross_entropy``.
            optimizer: optimizer to step; must be provided when ``training`` is truthy.

        Returns:
            dict with Python numbers under ``'loss'``, ``'acc1'`` and ``'acc5'``
            (top-1 / top-5 accuracy as computed by torchfurnace's ``accuracy``).
        """
        # Only record the autograd graph when a backward pass will follow, so
        # evaluation batches do not keep activations alive needlessly. When the
        # caller already controls grad mode this is a no-op.
        with torch.set_grad_enabled(bool(training)):
            output = model(inp)
            loss = F.cross_entropy(output, target)

        if training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        return {'loss': loss.item(), 'acc1': acc1.item(), 'acc5': acc5.item()}

    @staticmethod
    def _get_lr_scheduler(optim) -> list:
        """Return the learning-rate schedulers for ``optim``.

        The learning rate is multiplied by 0.1 at each milestone 150, 250 and 350
        (presumably epochs, if ``Engine`` steps the scheduler once per epoch — confirm).
        """
        return [MultiStepLR(optim, milestones=[150, 250, 350], gamma=0.1)]
def main():
    """Train VGG16 on CIFAR10 with the torchfurnace engine and print the result.

    Options come from torchfurnace's ``Parser``; the experiment is named
    ``<dataset>_<exp_suffix>``. The value returned by ``learning`` is printed
    with the ``Acc1:`` label.
    """
    # Command-line configuration; 'TVGG16' is the name handed to Parser.
    cli = Parser('TVGG16')
    opts = cli.parse_args()
    exp_name = '_'.join([opts.dataset, opts.exp_suffix])

    # Data: tensor conversion followed by per-channel normalisation (mean, std).
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    dataset_kwargs = dict(root='data', download=True, transform=preprocess)
    train_data = CIFAR10(train=True, **dataset_kwargs)
    test_data = CIFAR10(train=False, **dataset_kwargs)

    # Model: VGG16 from scratch with a single linear head over 512 features.
    model = models.vgg16(pretrained=False, num_classes=10)
    # NOTE(review): for 32x32 inputs the VGG feature stack presumably yields
    # 512x1x1 maps, making this 1x1 average pool an identity — confirm.
    model.avgpool = nn.AvgPool2d(kernel_size=1, stride=1)
    model.classifier = nn.Linear(512, 10)

    # NOTE(review): Adam runs with its default learning rate; the ``-lr`` CLI
    # option is not passed to the optimizer here — confirm whether Engine applies it.
    adam = torch.optim.Adam(model.parameters())

    engine = VGGNetEngine(cli, exp_name)
    result = engine.learning(model, adam, train_data, test_data)
    print('Acc1:', result)
if __name__ == '__main__':
    import sys

    # Demo defaults for this example. They are inserted right after the script
    # name so that options given on the real command line come *later*: for
    # argparse store-style options the last occurrence wins, so user-supplied
    # values override these defaults instead of being silently overridden by
    # them. (Pure flags such as --adjust_lr cannot be switched off this way.)
    run_params = '--dataset CIFAR10 -lr 0.1 -bs 128 -j 2 --epochs 400 --adjust_lr'
    sys.argv[1:1] = run_params.split()
    main()