• mxnet包含NDArray的列表更新


    Jul 26, 2017

    之前写的用来人工设定batch_sizeacc_xxx发现出现了问题。最终发现是列表更新的问题。
    想想之前的NDArray处理,也是奇葩了。比如,你能告诉我下面这段中注释的与非注释的,产生差别的原理么?...what?!居然会有差别?You kidding?

        def acc_update(self,normsize=1):
            assert self.binded and self.params_initialized and self.optimizer_initialized
    #        self._curr_module._exec_group.grad_arrays=
    #                      [[grad.copyto(grad.context)/normsize if grad is not None else None for grad in grads] for grads in self.grad]
            for acc_grads, mod_grads in zip(self.grad,self._curr_module._exec_group.grad_arrays):
                for acc_grad, mod_grad in zip(acc_grads, mod_grads):
                    if acc_grad is not None:
                        mod_grad = acc_grad.copyto(mod_grad.context)/normsize
                    else:
                        mod_grad = None
            ...
    

    Oct 22, 2017

    最近发现接口又改了((⊙﹏⊙)b),新版的(V0.11.1)里面这样做也不合适,用分片的方法可能是对的(从一些结果上来看,还不能肯定没问题)。

    Oct 23, 2017

    对比了累计更新和一次更新作为一个batch的输出,初步验证程序的正确性。


    两处的目的都很明显:想用self.grad的内容更新self._curr_module._exec_group.grad_arrays
    然而调试的结果是,没被注释掉的能够完成这项预期,另外一个不能(可能是暂时的)归纳出其规律,表现某种单一增长的特征。
    感觉上应该是列表之间的替换,但却没有这样运行
    后面再看问题出在哪?

    Sep 13, 2017

    没有发现可能的问题,打算先放一放了。开始的时候打算从_exec_group.grad_arays的接口入手,发现是从_exec_group._execs[].grad_array中传过来的,但在update的时候,用的是前者,猜测可能在更新前有过同步,在没有找到的情况下,直接将后者del或者赋值为None,但都没有效果,和昨天的情况保持了相同;此外,将上段程序中的normsize设置为0发现,更新后确实也没有显著变化(细微的变化应该是由weight decay引起的——仅在小数点变化),也就是说,acc_update中的self.grad确实起到了作用。所以被凌乱了...mess
    贴上两次的结果对比吧,以伺观code者得焉

    >>> mod.forward(d)
    >>> #mod.get_outputs()[0].asnumpy().max()
    ... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
    5873.4209
    >>> #mod.backward()
    ... #mod.update()
    ... mod.acc_backward()
    >>> mod.acc_update()
    >>> mod.forward(d)
    >>> #mod.get_outputs()[0].asnumpy().max()
    ... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
    269556.56
    >>> #mod.backward()
    ... #mod.update()
    ... mod.acc_backward()
    >>> mod.acc_update()
    >>> mod.forward(d)
    >>> #mod.get_outputs()[0].asnumpy().max()
    ... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
    1444888.8
    >>> #mod.backward()
    ... #mod.update()
    ... mod.acc_backward()
    >>> mod.acc_update()
    >>> mod.forward(d)
    >>> #mod.get_outputs()[0].asnumpy().max()
    ... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
    4637960.0
    >>> #mod.backward()
    ... #mod.update()
    ... mod.acc_backward()
    >>> mod.acc_update()
    >>> mod.forward(d)
    >>> #mod.get_outputs()[0].asnumpy().max()
    ... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
    11257292.0
    >>> #mod.backward()
    ... #mod.update()
    ... mod.acc_backward()
    >>> mod.acc_update()
    >>> 
    >>> mod.forward(d)
    >>> #mod.get_outputs()[0].asnumpy().max()
    ... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
    22884572.0
    >>> #mod.backward()
    ... #mod.update()
    ... mod.acc_backward()
    >>> mod.acc_update()
    >>> mod.forward(d)
    >>> #mod.get_outputs()[0].asnumpy().max()
    ... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
    41182624.0
    >>> #mod.backward()
    ... #mod.update()
    ... mod.acc_backward()
    >>> mod.acc_update()
    

    上面这段应该是非正常结果的,下面这段是归为正常结果。

    >>> mod.forward(d)
    >>> #mod.get_outputs()[0].asnumpy().max()
    ... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
    5873.4209
    >>> #mod.backward()
    ... #mod.update()
    ... mod.acc_backward()
    >>> mod.acc_update()
    >>> 
    >>> mod.forward(d)
    >>> #mod.get_outputs()[0].asnumpy().max()
    ... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
    269556.56
    >>> #mod.backward()
    ... #mod.update()
    ... mod.acc_backward()
    >>> mod.acc_update()
    >>> 
    >>> mod.forward(d)
    >>> #mod.get_outputs()[0].asnumpy().max()
    ... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
    741699.12
    >>> #mod.backward()
    ... #mod.update()
    ... mod.acc_backward()
    >>> mod.acc_update()
    >>> mod.forward(d)
    >>> #mod.get_outputs()[0].asnumpy().max()
    ... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
    33154.039
    >>> #mod.backward()
    ... #mod.update()
    ... mod.acc_backward()
    >>> mod.acc_update()
    >>> mod.forward(d)
    >>> #mod.get_outputs()[0].asnumpy().max()
    ... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
    30.383278
    >>> #mod.backward()
    ... #mod.update()
    ... mod.acc_backward()
    >>> mod.acc_update()
    >>> mod.forward(d)
    >>> #mod.get_outputs()[0].asnumpy().max()
    ... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
    30.155006
    >>> #mod.backward()
    ... #mod.update()
    ... mod.acc_backward()
    >>> mod.acc_update()
    
  • 相关阅读:
    套接字编程,创建套接字socket
    网络编程基本原理
    进一步学习的书籍
    C# 基础备忘录
    二进制转文件以及文件压缩和解压缩
    使用用WCF中的双工(Duplex)模式将广告图片推送到每个Winform客户端机子上
    C#两个日期范围内的间隔
    C#中XML文档注释编译DLL引用到其它项目
    用T4模版生成对应数据库表的实体类
    lodop打印控件需要开启的几个计算机服务
  • 原文地址:https://www.cnblogs.com/chenyliang/p/7512347.html
Copyright © 2020-2023  润新知