• pytorch/Python的一些函数用法(日常更新)


    torch.nn.Embedding(num_embeddings: int, embedding_dim: int)
    是用来embed词成为word embedding的
    num_embeddings (int): size of the dictionary of embeddings
    embedding_dim (int): the size of each embedding vector

    例如:self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)




    Python内置函数:getattr(object, name[, default])
    • object -- 对象。
    • name -- 字符串,对象属性。
    • default -- 默认返回值,如果不提供该参数,在没有对应属性时,将触发 AttributeError。

    X.expand(size)

    用来将X原样复制size遍,成为size的shape,例如

    t
    Out[22]: tensor([0, 1, 2, 3])
    t.expand((2,3,-1))
    Out[23]:
    tensor([[[0, 1, 2, 3],
             [0, 1, 2, 3],
             [0, 1, 2, 3]],
            [[0, 1, 2, 3],
             [0, 1, 2, 3],
             [0, 1, 2, 3]]])

    register_buffer

    应该就是在内存中定义一个常量,同时,模型保存和加载的时候可以写入和读出。

    例如:self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))

    tensor.permute()

    维度转换

    比如图片img的size比如是(28,28,3)就可以利用img.permute(2,0,1)得到一个size为(3,28,28)的tensor。

    利用这个函数permute(0,2,1)可以把Tensor([[[1,2,3],[4,5,6]]]) 转换成

    1. tensor([[[1., 4.],
    2. [2., 5.],
    3. [3., 6.]]])

    如果使用view,可以得到

    1. tensor([[[1., 2.],
    2. [3., 4.],
    3. [5., 6.]]])

    tensor.view(-1)

    把tensor中所有数字放置成一个list返回

    import torch
    a = torch.randn(3,5,2)
    print(a)
    print(a.view(-1))
    

    运行结果:

    tensor([[[-0.6887,  0.2203],
             [-1.6103, -0.7423],
             [ 0.3097, -2.9694],
             [ 1.2073, -0.3370],
             [-0.5506,  0.4753]],
    
            [[-1.3605,  1.9303],
             [-1.5382, -1.0865],
             [-0.9208, -0.1754],
             [ 0.1476, -0.8866],
             [ 0.4519,  0.2771]],
    
            [[ 0.6662,  1.1027],
             [-0.0912, -0.6284],
             [-1.0253, -0.3542],
             [ 0.6909, -1.3905],
             [-2.1140,  1.3426]]])
    tensor([-0.6887,  0.2203, -1.6103, -0.7423,  0.3097, -2.9694,  1.2073, -0.3370,
            -0.5506,  0.4753, -1.3605,  1.9303, -1.5382, -1.0865, -0.9208, -0.1754,
             0.1476, -0.8866,  0.4519,  0.2771,  0.6662,  1.1027, -0.0912, -0.6284,
            -1.0253, -0.3542,  0.6909, -1.3905, -2.1140,  1.3426])

    Optional[X]

    等价于Union[X, None]

    from typing import Optional
    
    def foo_v2(a: int, b: Optional[int] = None):
        if b:
            print(a + b)
        else:
            print("parameter b is a NoneType!")
    
    #只传入a位置的实参
    foo_v2(2)
    
    # 输出
    >>> parameter b is a NoneType!

    d

  • 相关阅读:
    状压DP【p1879】[USACO06NOV]玉米田Corn Fields
    Tarjan缩点+Spfa最长路【p3627】[APIO2009] 抢掠计划
    Tarjan缩点【p1726】上白泽慧音
    分层图【p4822】[BJWC2012]冻结
    Tarjan缩点+LCA【p2783】有机化学之神偶尔会做作弊
    线段树【p1607】[USACO09FEB]庙会班车Fair Shuttle
    better-scroll踩坑合集
    在浏览器上安装 Vue Devtools工具
    无法执行vue初始化命令
    vue-cli创建第一个项目(用git bash解决上下键移动选择问题)
  • 原文地址:https://www.cnblogs.com/gagaein/p/14391853.html
Copyright © 2020-2023  润新知