CVPR2021文章adco官方代码阅读笔记
1. 正样本相似性计算用的是矩阵乘法,即同一个batch的其它样本也做为负样本(而不仅仅是memory bank中的样本)
l_pos = torch.enisum('nc,ck->nk', [q, k.T]) # train.py update_network()
而不是像moco一样直接做向量点积
l_pos = torch.enisum('nc,nc->n', [q, k])
作者在issue中说这样能更好的的避免collapsing。没明白为什么,欢迎在评论中指导。
2. memory bank的更新加入了momentum,而不是直接替换W
1 Memory_Bank.v.data = args.momentum * Memory_Bank.v.data + g + args.mem_wd * Memory_Bank.W.data 2 Memory_Bank.W.data = Memory_Bank.W.data - args.memory_lr * Memory_Bank.v.data
【参考作者的github issue[3].】