• pytorch中torch.chunk()方法


    chunk方法可以对张量分块,返回一个张量列表:

    torch.chunk(tensorchunksdim=0) → List of Tensors

    Splits a tensor into a specific number of chunks.

    Last chunk will be smaller if the tensor size along the given dimension dim is not divisible by chunks.(如果指定轴的元素个数被chunks除不尽,那么最后一块的元素个数变少)

    Parameters:
    • tensor (Tensor) – the tensor to split
    • chunks (int) – number of chunks to return(分割的块数)
    • dim (int) – dimension along which to split the tensor(沿着哪个轴分块)
     
    import numpy as np
    import torch
    
    data = torch.from_numpy(np.random.rand(3, 5))
    print(str(data))
    >>
    tensor([[0.6742, 0.5700, 0.3519, 0.4603, 0.9590],
            [0.9705, 0.8673, 0.8854, 0.9029, 0.5473],
            [0.0199, 0.4729, 0.4001, 0.7581, 0.5045]], dtype=torch.float64)
    
    for i, data_i in enumerate(data.chunk(5, 1)): # 沿1轴分为5块
        print(str(data_i))
    >>
    tensor([[0.6742],
            [0.9705],
            [0.0199]], dtype=torch.float64)
    tensor([[0.5700],
            [0.8673],
            [0.4729]], dtype=torch.float64)
    tensor([[0.3519],
            [0.8854],
            [0.4001]], dtype=torch.float64)
    tensor([[0.4603],
            [0.9029],
            [0.7581]], dtype=torch.float64)
    tensor([[0.9590],
            [0.5473],
            [0.5045]], dtype=torch.float64)  
    
    for i, data_i in enumerate(data.chunk(3, 0)): # 沿0轴分为3块
        print(str(data_i))
    >>
    tensor([[0.6742, 0.5700, 0.3519, 0.4603, 0.9590]], dtype=torch.float64)
    tensor([[0.9705, 0.8673, 0.8854, 0.9029, 0.5473]], dtype=torch.float64)
    tensor([[0.0199, 0.4729, 0.4001, 0.7581, 0.5045]], dtype=torch.float64)
       
    for i, data_i in enumerate(data.chunk(3, 1)): # 沿1轴分为3块,除不尽
        print(str(data_i))
    >>
    tensor([[0.6742, 0.5700],
            [0.9705, 0.8673],
            [0.0199, 0.4729]], dtype=torch.float64)
    tensor([[0.3519, 0.4603],
            [0.8854, 0.9029],
            [0.4001, 0.7581]], dtype=torch.float64)
    tensor([[0.9590],
            [0.5473],
            [0.5045]], dtype=torch.float64)

     

  • 相关阅读:
    iOS之App Store上架被拒Legal
    iOS之解决崩溃Collection <__NSArrayM: 0xb550c30> was mutated while being enumerated.
    iOS之延时执行(睡眠)的几种方法
    iOS之UILabel的自动换行
    iOS之开发中一些相关的路径以及获取路径的方法
    iOS之绘制虚线
    iOS之判断手机号码、邮箱格式是否正确
    iOS之计算上次日期距离现在多久, 如 xx 小时前、xx 分钟前等
    iOS之开发中常用的颜色及其对应的RGB值
    method invocation
  • 原文地址:https://www.cnblogs.com/jiangkejie/p/10309766.html
Copyright © 2020-2023  润新知