• pytorch记录


    有两个tensor是A和B

    C = torch.cat( (A,B),0 ) #按维数0拼接(竖着拼) C = torch.cat( (A,B),1 ) #按维数1拼接(横着拼)
     A = torch.ones(2,3)
     B = torch.ones(4,3)
     out=torch.cat((A,B),0)
    tensor([[1., 1., 1.],
            [1., 1., 1.],
            [1., 1., 1.],
            [1., 1., 1.],
            [1., 1., 1.],
            [1., 1., 1.]])
    
    
    C = torch.ones(2,5)
    out = torch.cat((A,C),1)
    tensor([[1., 1., 1., 1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1., 1., 1., 1.]])
     max_test = torch.Tensor([[5,8,1],[3,1,9]])
    tensor([[5., 8., 1.],
            [3., 1., 9.]])
    
    max_test.max(1,keepdim=True)
    values=tensor([[8.],
            [9.]]),
    indices=tensor([[1],
            [2]]))
    
     max_test.max(1)
    torch.return_types.max(
    values=tensor([8., 9.]),
    indices=tensor([1, 2]))
    
    max_test.max(0)
    values=tensor([5., 8., 9.]),
    indices=tensor([0, 0, 1]))
    
    max_test.max(0,keepdim=True)
    torch.return_types.max(
    values=tensor([[5., 8., 9.]]),
    indices=tensor([[0, 0, 1]]))
    valid_idx = torch.tensor([True, False, True, False, False]) #小写的t,long类型
    a = torch.tensor([1,2,3,4,5])
    idx_filter = a[valid_idx]
    tensor([1, 3])
    b = torch.Tensor([[1,2,3]])
    b.squeeze(0)
     b
    tensor([[1., 2., 3.]])
    
    b.squeeze_(0)
    b
    tensor([1., 2., 3.])
    a = torch.ones(3,5)
    index = torch.tensor([0,2])
    a.index_fill_(0,index,100)
    tensor([[100., 100., 100., 100., 100.],
            [  1.,   1.,   1.,   1.,   1.],
            [100., 100., 100., 100., 100.]])
    
    
    b = torch.ones(3,5)
    b.index_fill(1,index,200)
    tensor([[200.,   1., 200.,   1.,   1.],
            [200.,   1., 200.,   1.,   1.],
            [200.,   1., 200.,   1.,   1.]])
     labels= torch.rand(5,4)
    tensor([[0.2833, 0.7600, 0.6912, 0.5421],
            [0.3498, 0.0440, 0.3356, 0.5975],
            [0.9071, 0.2023, 0.9391, 0.2516],
            [0.9536, 0.0939, 0.4833, 0.7402],
            [0.2392, 0.7111, 0.9192, 0.5417]])
     best_idx = torch.tensor([3,3,3,0,0,0,0])
    labels[best_idx]
    tensor([[0.9536, 0.0939, 0.4833, 0.7402],
            [0.9536, 0.0939, 0.4833, 0.7402],
            [0.9536, 0.0939, 0.4833, 0.7402],
            [0.2833, 0.7600, 0.6912, 0.5421],
            [0.2833, 0.7600, 0.6912, 0.5421],
            [0.2833, 0.7600, 0.6912, 0.5421],
            [0.2833, 0.7600, 0.6912, 0.5421]])
  • 相关阅读:
    Global.asax 事件备忘
    JavaScript异常捕捉
    还记得 virtual 吗?我们来温故知新下吧。
    开发(ASP.NET程序)把写代码写至最有面向对象味道
    MVC中实现 "加载更多..."
    js和C#中的编码和解码(备忘)
    System.AccessViolationException: 尝试读取或写入受保护的内存。这通常指示其他内存已损坏
    10种提高WordPress访问速度的方法
    使用Python3实现Telnet功能
    读书计划(不断更新)201904
  • 原文地址:https://www.cnblogs.com/crazybird123/p/14686357.html
Copyright © 2020-2023  润新知