• Pytorch的torch.cat实例


    import torch
    

      

    通过 help((torch.cat)) 可以查看 cat 的用法
    cat(seq,dim,out=None)
     
    其中 seq表示要连接的两个序列,以元组的形式给出,例如:seq=(a,b),  a,b 为两个可以连接的序列
    dim 表示以哪个维度连接,dim=0, 横向连接
                          dim=1,纵向连接
     
    

     

    #实例:
     
        #dim=0 时:
        
        import torch
        n_data = torch.ones((100,2))
        x0_data = torch.normal(2*n_data,1)
        y0_data = torch.zeros((100,1))
        x1_data = torch.normal(-2*n_data,1)
        y1_data = torch.ones((100,1))
        x_data = torch.cat((x0_data,x1_data),0).type(torch.FloatTensor)
        y_data = torch.cat((y0_data,y1_data),0).type(torch.LongTensor)
        print('x_data的形状:',x_data.shape)
        print("y_data的形状:",y_data.shape)
    

      

    result:
        
        x_data的形状: torch.Size([200, 2])
        y_data的形状: torch.Size([200, 1])
    

      

    #实例:
     
        #dim=1 时:
        
        import torch
        n_data = torch.ones((100,2))
        x0_data = torch.normal(2*n_data,1)
        y0_data = torch.zeros((100,1))
        x1_data = torch.normal(-2*n_data,1)
        y1_data = torch.ones((100,1))
        x_data = torch.cat((x0_data,x1_data),1).type(torch.FloatTensor)
        y_data = torch.cat((y0_data,y1_data),1).type(torch.LongTensor)
        print('x_data的形状:',x_data.shape)
        print("y_data的形状:",y_data.shape)
    

      

    result:
     
        x_data的形状: torch.Size([100, 4])
        y_data的形状: torch.Size([100, 2])
    

      

  • 相关阅读:
    BUPT复试专题—最长连续等差子数列(2014软院)
    BUPT复试专题—奇偶求和(2014软件)
    BUPT复试专题—网络传输(2014网研)
    Hopscotch(POJ 3050 DFS)
    Backward Digit Sums(POJ 3187)
    Smallest Difference(POJ 2718)
    Meteor Shower(POJ 3669)
    Red and Black(poj 1979 bfs)
    测试
    Equations(hdu 1496 二分查找+各种剪枝)
  • 原文地址:https://www.cnblogs.com/pythonClub/p/10412418.html
Copyright © 2020-2023  润新知