• pytorch网络转libtorch常见问题



    一、All inputs of range must be ints, found Tensor in argument 0:

    问题
    参数类型不正确,函数的默认参数是tensor

    解决措施
    函数传入参数不是tensor需要注明类型
    我的问题是传入参数npoint是一个int类型,没有注明会报错,更改如下:

    def test(npoint):
      ...
    

    更改为

    def test(npoint: int):
      ...
    

    二、Sliced expression not yet supported for subscripted assignment. File a bug if you want this:

    问题
    不支持赋值给切片表达式

    解决措施
    根据自己需求,进行修改,可利用循环替代

    我将view_shape[1:] = [1] * (len(view_shape) - 1)更改为

        for i in range(1, len(view_shape)):
            view_shape[i] = 1
    

    三、Tried to access nonexistent attribute or method 'len' of type 'torch.torch.nn.modules.container.ModuleList'. Did you forget to initialize an attribute in init()?

    问题
    forward函数中好像不支持len(nn.ModuleList())和下标访问

    解决措施
    如果是一个ModuleList()可以用enumerate函数,多个同维度的可以用zip函数

    我这里有两个ModuleList(),所以采用zip函数,更改如下:

       for i, conv in enumerate(self.mlp_convs):
          bn = self.mlp_bns[i]
          new_points = F.relu(bn(conv(new_points)))
    

    更改为

        for conv, bn in zip(self.mlp_convs, self.mlp_bns):
            new_points = F.relu(bn(conv(new_points)))
    

    ref: https://github.com/pytorch/pytorch/issues/16123


    四、Expected integer literal for index

    问题和解决方法类似第三个


    五、Arguments for call are not valid. The following variants are available

    Expected a value of type 'List[Tensor]' for argument 'indices' but instead found type 'List[Optional[Tensor]]'

    问题
    赋值类型不对,需求是tensor,但给的是int

    解决措施

    • 方法1
      int类型的数Ntorch.tensor(N)代替
    mask = sqrdists > radius ** 2
    group_idx[mask] = N
    

    变为

    mask = sqrdists > radius ** 2
    group_idx[mask] = torch.tensor(N)
    
    • 方法2 (速度较慢)
      for循环替代`
    mask = sqrdists > radius ** 2
    group_idx[mask] = N
    

    变为

    B, rows, cols = sqrdists.shape
    ref_redius = radius ** 2
    for b in range(B):
        for r in range(rows):
            print("r: ", r)
            for c in range(cols):
                if sqrdists[b][r][c] > ref_redius:
                    group_idx[b][r][c] = N
    
  • 相关阅读:
    如何正确夸奖孩子
    C# datatable分页和 list 分页
    js修改Switchery复选框的状态
    虚拟机中centos中设置固定IP
    CommonJS和ES6
    npm使用淘宝镜像
    RabbitMQ基础概念详细介绍
    Web漏洞扫描神器Nikto使用指南
    Redis基本使用
    ROS文件系统导览
  • 原文地址:https://www.cnblogs.com/xiaxuexiaoab/p/15555066.html
Copyright © 2020-2023  润新知