• mxnet包含NDArray的列表更新


    Jul 26, 2017

    之前写的用来人工设定batch_sizeacc_xxx发现出现了问题。最终发现是列表更新的问题。
    想想之前的NDArray处理,也是奇葩了。比如,你能告诉我下面这段中注释的与非注释的,产生差别的原理么?...what?!居然会有差别?You kidding?

        def acc_update(self,normsize=1):
            assert self.binded and self.params_initialized and self.optimizer_initialized
    #        self._curr_module._exec_group.grad_arrays=
    #                      [[grad.copyto(grad.context)/normsize if grad is not None else None for grad in grads] for grads in self.grad]
            for acc_grads, mod_grads in zip(self.grad,self._curr_module._exec_group.grad_arrays):
                for acc_grad, mod_grad in zip(acc_grads, mod_grads):
                    if acc_grad is not None:
                        mod_grad = acc_grad.copyto(mod_grad.context)/normsize
                    else:
                        mod_grad = None
            ...
    

    Oct 22, 2017

    最近发现接口又改了((⊙﹏⊙)b),新版的(V0.11.1)里面这样做也不合适,用分片的方法可能是对的(从一些结果上来看,还不能肯定没问题)。

    Oct 23, 2017

    对比了累计更新和一次更新作为一个batch的输出,初步验证程序的正确性。


    两处的目的都很明显:想用self.grad的内容更新self._curr_module._exec_group.grad_arrays
    然而调试的结果是,没被注释掉的能够完成这项预期,另外一个不能(可能是暂时的)归纳出其规律,表现某种单一增长的特征。
    感觉上应该是列表之间的替换,但却没有这样运行
    后面再看问题出在哪?

    Sep 13, 2017

    没有发现可能的问题,打算先放一放了。开始的时候打算从_exec_group.grad_arays的接口入手,发现是从_exec_group._execs[].grad_array中传过来的,但在update的时候,用的是前者,猜测可能在更新前有过同步,在没有找到的情况下,直接将后者del或者赋值为None,但都没有效果,和昨天的情况保持了相同;此外,将上段程序中的normsize设置为0发现,更新后确实也没有显著变化(细微的变化应该是由weight decay引起的——仅在小数点变化),也就是说,acc_update中的self.grad确实起到了作用。所以被凌乱了...mess
    贴上两次的结果对比吧,以伺观code者得焉

    >>> mod.forward(d)
    >>> #mod.get_outputs()[0].asnumpy().max()
    ... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
    5873.4209
    >>> #mod.backward()
    ... #mod.update()
    ... mod.acc_backward()
    >>> mod.acc_update()
    >>> mod.forward(d)
    >>> #mod.get_outputs()[0].asnumpy().max()
    ... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
    269556.56
    >>> #mod.backward()
    ... #mod.update()
    ... mod.acc_backward()
    >>> mod.acc_update()
    >>> mod.forward(d)
    >>> #mod.get_outputs()[0].asnumpy().max()
    ... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
    1444888.8
    >>> #mod.backward()
    ... #mod.update()
    ... mod.acc_backward()
    >>> mod.acc_update()
    >>> mod.forward(d)
    >>> #mod.get_outputs()[0].asnumpy().max()
    ... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
    4637960.0
    >>> #mod.backward()
    ... #mod.update()
    ... mod.acc_backward()
    >>> mod.acc_update()
    >>> mod.forward(d)
    >>> #mod.get_outputs()[0].asnumpy().max()
    ... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
    11257292.0
    >>> #mod.backward()
    ... #mod.update()
    ... mod.acc_backward()
    >>> mod.acc_update()
    >>> 
    >>> mod.forward(d)
    >>> #mod.get_outputs()[0].asnumpy().max()
    ... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
    22884572.0
    >>> #mod.backward()
    ... #mod.update()
    ... mod.acc_backward()
    >>> mod.acc_update()
    >>> mod.forward(d)
    >>> #mod.get_outputs()[0].asnumpy().max()
    ... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
    41182624.0
    >>> #mod.backward()
    ... #mod.update()
    ... mod.acc_backward()
    >>> mod.acc_update()
    

    上面这段应该是非正常结果的,下面这段是归为正常结果。

    >>> mod.forward(d)
    >>> #mod.get_outputs()[0].asnumpy().max()
    ... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
    5873.4209
    >>> #mod.backward()
    ... #mod.update()
    ... mod.acc_backward()
    >>> mod.acc_update()
    >>> 
    >>> mod.forward(d)
    >>> #mod.get_outputs()[0].asnumpy().max()
    ... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
    269556.56
    >>> #mod.backward()
    ... #mod.update()
    ... mod.acc_backward()
    >>> mod.acc_update()
    >>> 
    >>> mod.forward(d)
    >>> #mod.get_outputs()[0].asnumpy().max()
    ... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
    741699.12
    >>> #mod.backward()
    ... #mod.update()
    ... mod.acc_backward()
    >>> mod.acc_update()
    >>> mod.forward(d)
    >>> #mod.get_outputs()[0].asnumpy().max()
    ... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
    33154.039
    >>> #mod.backward()
    ... #mod.update()
    ... mod.acc_backward()
    >>> mod.acc_update()
    >>> mod.forward(d)
    >>> #mod.get_outputs()[0].asnumpy().max()
    ... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
    30.383278
    >>> #mod.backward()
    ... #mod.update()
    ... mod.acc_backward()
    >>> mod.acc_update()
    >>> mod.forward(d)
    >>> #mod.get_outputs()[0].asnumpy().max()
    ... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
    30.155006
    >>> #mod.backward()
    ... #mod.update()
    ... mod.acc_backward()
    >>> mod.acc_update()
    
  • 相关阅读:
    14.4 exportfs命令 14.5 NFS客户端问题 15.1 FTP介绍 15.2/15.3 使用vsftpd搭建ftp
    14.1 NFS介绍 14.2 NFS服务端安装配置 14.3 NFS配置选项
    13.4 mysql用户管理 13.5 常用sql语句 13.6 mysql数据库备份恢复
    13.1 设置更改root密码 13.2 连接mysql 13.3 mysql常用命令
    12.21 php-fpm的pool 12.22 php-fpm慢执行日志 12.23 open_basedir 12.24 php-fpm进程管理
    12.17 Nginx负载均衡 12.18 ssl原理 12.19 生成ssl密钥对 12.20 Nginx配置ssl
    12.13 Nginx防盗链 12.14 Nginx访问控制 12.15 Nginx解析php相关配置 12.16 Nginx代理
    在界面workspacebar上添加的OnContextMenu函数打的断点始终进不去,显示当前不会命中断点还未为文档加载任何符号
    VS2019安装闪退, 不出现安装界面,解决办法
    BCGP单元格多列合并需要注意的
  • 原文地址:https://www.cnblogs.com/chenyliang/p/7512347.html
Copyright © 2020-2023  润新知