tensor的拼接与拆分
cat函数
例子:成绩单的合并
【班级1~4 学生 得分】
【班级5~9 学生 得分】
在0维进行合并,非cat的维度必须一致
a = torch.rand(4,32,8)
b = torch.rand(5,32,8)
c = torch.cat([a,b],dim=0)
c.shape()
#[9,32,8]
stack函数
会新添加一个维度,要保证两个stack的tensor的维度一摸一样
,在理解方面是添加了新的概念在里面。
例子:
一班:【32个学生 每个学生8门课程】
二班:【32个学生 每个学生8门课程】
stack之后变为【两个班级 每个班级32个学生 每个学生有8门课程】
a = torch.rand(32,8)
b = torch.rand(32,8)
torch.stack([a,b],dim=0).shape
#[2 32 8]
split函数
split函数按照长度来拆分
例子1:
参数说明:【1,1】表示前面的长度为1,后面的长度也是1
a = torch.rand(2,32,8)
b,c = torch.split([1,1],dim=0)
b.shape
#[1,32,8]
c.shape()
#[1,32,8]
例子2:
参数说明:【2,1】表示前面的长度为2,后面的长度为1
(不规则分割的参数含义)
a = torch.rand(3,32,8)
b,c = torch.split([2,1],dim=0)
b.shape
#[2,32,8]
c.shape()
#[1,32,8]
chunk函数
根据数量来进行分割(尽量实现整除,后面除不尽的留给最后)
例子:
a = torch.rand(6,32,8)
b,c,d= torch.chunk(a,3,dim=0)
print(b.shape)
print(c.shape)
print(d.shape)
#torch.Size([2, 32, 8])
#torch.Size([2, 32, 8])
#torch.Size([2, 32, 8])
例子2:
a = torch.rand(5,32,8)
b,c,d= torch.chunk(a,3,dim=0)
print(b.shape)
print(c.shape)
print(d.shape)
#torch.Size([2, 32, 8])
#torch.Size([2, 32, 8])
#torch.Size([1, 32, 8])
例子3:
a = torch.rand(5,32,8)
b,c= torch.chunk(a,2,dim=0)
print(b.shape)
print(c.shape)
#torch.Size([3, 32, 8])
#torch.Size([2, 32, 8])