The chunk method splits a tensor into pieces and returns them as a sequence of tensors:
torch.chunk(tensor, chunks, dim=0) → List of Tensors
Splits a tensor into a specific number of chunks. The last chunk will be smaller if the tensor size along the given dimension dim is not divisible by chunks.
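To make the size rule concrete, here is a small pure-Python sketch. It assumes that every chunk gets ceil(n / chunks) elements and the last one takes whatever remains, which matches the outputs in the example later in this post; chunk_sizes is a hypothetical helper for illustration, not part of PyTorch. One consequence of rounding up is that fewer than chunks pieces can come back (the PyTorch documentation warns about this), e.g. 13 elements split into 6 chunks gives only 5 pieces.

```python
from typing import List


def chunk_sizes(n: int, chunks: int) -> List[int]:
    """Hypothetical helper: sizes of the pieces torch.chunk would produce
    along a dimension of length n (assumes n > 0 and chunks > 0)."""
    if n <= 0 or chunks <= 0:
        raise ValueError("n and chunks must both be positive")
    step = (n + chunks - 1) // chunks  # ceil(n / chunks) using integer arithmetic
    sizes = []
    start = 0
    while start < n:
        sizes.append(min(step, n - start))  # the last piece takes whatever is left
        start += step
    return sizes


print(chunk_sizes(5, 5))   # [1, 1, 1, 1, 1]
print(chunk_sizes(5, 3))   # [2, 2, 1]
print(chunk_sizes(13, 6))  # [3, 3, 3, 3, 1] -> only 5 pieces
```

For the 3×5 example later in this post, chunk_sizes(5, 3) gives [2, 2, 1], which matches the column counts of the three chunks produced by data.chunk(3, 1).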
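Parameters:
    tensor (Tensor): the tensor to split
    chunks (int): number of chunks to return
    dim (int): dimension along which to split the tensor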
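A short sketch of how the three parameters are passed. It relies only on documented behaviour: chunk is available both as torch.chunk and as the Tensor.chunk method, and dim accepts negative values counted from the last dimension. The tensor x is just an illustrative input. Note that in current PyTorch versions the result actually comes back as a Python tuple rather than a list.

```python
import torch

x = torch.arange(12).reshape(3, 4)  # illustrative input of shape (3, 4)

# Function form and method form are equivalent
a = torch.chunk(x, 2, dim=1)
b = x.chunk(2, dim=1)

# dim=-1 means the last dimension, i.e. dim=1 for a 2-D tensor
c = x.chunk(2, dim=-1)

print(type(a).__name__)      # tuple
print([t.shape for t in a])  # [torch.Size([3, 2]), torch.Size([3, 2])]
print(all(torch.equal(p, q) for p, q in zip(a, c)))  # True
```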
```python
import numpy as np
import torch

# 3x5 tensor of random values; from_numpy keeps NumPy's float64 dtype.
# The printed values are from one run and will differ on yours.
data = torch.from_numpy(np.random.rand(3, 5))
print(data)
# tensor([[0.6742, 0.5700, 0.3519, 0.4603, 0.9590],
#         [0.9705, 0.8673, 0.8854, 0.9029, 0.5473],
#         [0.0199, 0.4729, 0.4001, 0.7581, 0.5045]], dtype=torch.float64)

for data_i in data.chunk(5, 1):  # split along dim 1 into 5 chunks
    print(data_i)
# tensor([[0.6742],
#         [0.9705],
#         [0.0199]], dtype=torch.float64)
# tensor([[0.5700],
#         [0.8673],
#         [0.4729]], dtype=torch.float64)
# tensor([[0.3519],
#         [0.8854],
#         [0.4001]], dtype=torch.float64)
# tensor([[0.4603],
#         [0.9029],
#         [0.7581]], dtype=torch.float64)
# tensor([[0.9590],
#         [0.5473],
#         [0.5045]], dtype=torch.float64)

for data_i in data.chunk(3, 0):  # split along dim 0 into 3 chunks
    print(data_i)
# tensor([[0.6742, 0.5700, 0.3519, 0.4603, 0.9590]], dtype=torch.float64)
# tensor([[0.9705, 0.8673, 0.8854, 0.9029, 0.5473]], dtype=torch.float64)
# tensor([[0.0199, 0.4729, 0.4001, 0.7581, 0.5045]], dtype=torch.float64)

for data_i in data.chunk(3, 1):  # split along dim 1 into 3 chunks; 5 is not divisible by 3
    print(data_i)                # so the chunks have 2, 2 and 1 columns
# tensor([[0.6742, 0.5700],
#         [0.9705, 0.8673],
#         [0.0199, 0.4729]], dtype=torch.float64)
# tensor([[0.3519, 0.4603],
#         [0.8854, 0.9029],
#         [0.4001, 0.7581]], dtype=torch.float64)
# tensor([[0.9590],
#         [0.5473],
#         [0.5045]], dtype=torch.float64)
```
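Two further properties are worth checking. According to the PyTorch documentation each chunk is a view of the input tensor, so an in-place write into a chunk also changes the original (this sketch uses a tensor that does not require gradients). And because the chunk size is rounded up, asking for more chunks than the size rule allows returns fewer pieces. The tensor x is only an illustrative input.

```python
import torch

x = torch.zeros(2, 6)
first, second, third = x.chunk(3, dim=1)  # three views of shape (2, 2)

first.fill_(1.0)  # in-place write through the view
print(x[0])       # tensor([1., 1., 0., 0., 0., 0.])

# 6 columns, 4 chunks requested: chunk size is ceil(6 / 4) = 2, so only 3 come back
parts = x.chunk(4, dim=1)
print(len(parts))  # 3
```

If you need independent copies instead of views, call .clone() on each chunk before modifying it.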