• 软工划水日报-paddle模型训练(3) 4/20


    今天来写正经八百的训练函数文件,然后开始训练!

    我的机器发出了异常的巨响,于是借用了一台散热器,现在写博客园只能开这一个页面,QQ都不能挂……

    我头一次感觉到我的电脑在性能方面不行了?!

    以下是代码:

    import os
    import shutil
    import webnet
    import paddle as paddle
    import firstdo
    import paddle.fluid as fluid
    
    paddle.enable_static()
    
    crop_size = 224
    resize_size = 250
    
    # 定义输入层
    image = fluid.layers.data(name='image', shape=[3, crop_size, crop_size], dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    
    # 获取分类器
    model = webnet.net(image, 61)
    
    # 获取损失函数和准确率函数
    cost = fluid.layers.cross_entropy(input=model, label=label)
    avg_cost = fluid.layers.mean(cost)
    acc = fluid.layers.accuracy(input=model, label=label)
    
    # 获取训练和测试程序
    test_program = fluid.default_main_program().clone(for_test=True)
    
    # 定义优化方法
    optimizer = fluid.optimizer.AdamOptimizer(learning_rate=1e-3,
                                              regularization=fluid.regularizer.L2DecayRegularizer(1e-4))
    opts = optimizer.minimize(avg_cost)
    
    database = 'C:/Users/14997/Desktop/database/'
    
    # 获取自定义数据
    train_reader = paddle.batch(reader=firstdo.train_reader(database+'train.list', crop_size, resize_size), batch_size=32)
    test_reader = paddle.batch(reader=firstdo. test_reader(database+'test.list', crop_size), batch_size=32)
    
    # 定义一个使用GPU的执行器
    # 这里可以用CPU,但奇慢无比,预计跑完全部的流程可能要200+小时
    # 但使用GPU对内存要求太大了
    place = fluid.CUDAPlace(0)
    # place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    # 进行参数初始化
    exe.run(fluid.default_startup_program())
    
    # 定义输入数据维度
    feeder = fluid.DataFeeder(place=place, feed_list=[image, label])
    
    # 训练100次
    for pass_id in range(50):
        # 进行训练
        for batch_id, data in enumerate(train_reader()):
            train_cost, train_acc = exe.run(program=fluid.default_main_program(),
                                            feed=feeder.feed(data),
                                            fetch_list=[avg_cost, acc])
    
            # 每100个batch打印一次信息
            if batch_id % 100 == 0:
                print('Pass:%d, Batch:%d, Cost:%0.5f, Accuracy:%0.5f' %
                      (pass_id, batch_id, train_cost[0], train_acc[0]))
                # 保存预测模型
                save_path = 'infer_model/'
                # 删除旧的模型文件
                shutil.rmtree(save_path, ignore_errors=True)
                # 创建保持模型文件目录
                os.makedirs(save_path)
                # 保存预测模型
                fluid.io.save_inference_model(save_path, feeded_var_names=[image.name], target_vars=[model], executor=exe)
    
        # 进行测试
        test_accs = []
        test_costs = []
        for batch_id, data in enumerate(test_reader()):
            test_cost, test_acc = exe.run(program=test_program,
                                          feed=feeder.feed(data),
                                          fetch_list=[avg_cost, acc])
            test_accs.append(test_acc[0])
            test_costs.append(test_cost[0])
        # 求测试结果的平均值
        test_cost = (sum(test_costs) / len(test_costs))
        test_acc = (sum(test_accs) / len(test_accs))
        print('Test:%d, Cost:%0.5f, Accuracy:%0.5f' % (pass_id, test_cost, test_acc))
  • 相关阅读:
    个人亲历运维面试
    《Kubernetes进阶实战》之管理Pod资源对象
    Docker容器必备技能 -- iptables
    vue后台管理权限正确思路
    Axios 各种请求方式传递参数格式
    Cookie的使用(js-cookie插件)
    微信小程序template模板与component组件的区别和使用
    如何机智地回答浏览器兼容性问题
    webpack系列5:源码流程,webpack编译流程
    webpack系列4:文件分析.
  • 原文地址:https://www.cnblogs.com/Sakuraba/p/14909993.html
Copyright © 2020-2023  润新知