• [pytorch]pytorch 将x按阈值条件进行变换/裁剪/映射


    pytorch 将x按阈值条件进行变换/裁剪/映射

    0. 场景

    在进行深度学习模型裁剪时,一个很显然的需求是将小于某个阈值的值全都设置为0,比如设置成阈值为0.5, 也即x[x<=0.5] = 0
    借助mask,很容易实现上述的需求,但推广起来,比如将对应位置设置为对应的fun,也即x[x<=0.5] = f1(x), x[x<=1.0 and x>0.5] = f2(x) 设计起来可能就没那么容易了。

    我们便遇到一个场景,要对不同的x实现不同的放缩,即x[x<=0.5] = w1 * x, x[x<=1.0 and x>0.5] = w2 * x, ...

    1. 方法

    首先需要将x的值进行映射,即将x映射到对应的id上,然后通过索引的方式找到对应的weight值,最后再相乘即可,下面是具体的代码实现:

    def magic_func(x):
        k = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]
        k = torch.tensor(k)
        
        idx = torch.tensor(x+8, dtype=torch.long)
        idx = torch.clamp(idx, min=0, max=15)
        weight_k = k[idx]
        
        y = weight_k.mul(x)
        
        return y
    

    通过这种方法,可以将x在[-8, -7]区间的置为1x,[-7, -6]的区间的置为2x。

  • 相关阅读:
    matlab 2021a 和 2021b共存的方案
    World Time Alighnment
    美化Xshell – 使用 Monokai 配色
    Centos 提示sudo: java: command not found解决办法
    typescript
    Spring MVC注册mapping
    Java执行JavaScript脚本
    vue3
    monaco editor
    rollup
  • 原文地址:https://www.cnblogs.com/wildkid1024/p/16273259.html
Copyright © 2020-2023  润新知