• [Deep Learning] Notes on Reading the PyTorch Documentation


    torch

    • is_tensor
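    • as_tensor: create a tensor from array-like data, sharing memory with the source when possible (no copy if the dtype and device already match); it does not turn on gradient tracking by itself
    • from_numpy: create a tensor from an ndarray; the tensor shares memory with the array (no copy)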
    • zeros(size)
    • zeros_like()
    • ones(size)
    • ones_like()
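    A minimal sketch of the creation helpers listed above (it assumes a standard PyTorch install with the default dtype left at float32; shapes and values are arbitrary). The *_like variants take their shape, dtype and device from an existing tensor.

```python
import torch

# zeros/ones take a shape; the *_like variants copy shape, dtype and device from another tensor
z = torch.zeros(2, 3)                    # float32 zeros
o = torch.ones(2, 3, dtype=torch.int64)  # int64 ones
z_like = torch.zeros_like(o)             # int64 zeros with o's shape
o_like = torch.ones_like(z)              # float32 ones with z's shape

print(torch.is_tensor(z), torch.is_tensor([0, 0]))  # True False
print(z_like.dtype, o_like.shape)                   # torch.int64 torch.Size([2, 3])
```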

    Creating random tensors:

    • bernoulli()
    • normal(mean,std)
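    • poisson(input): \(\text{out}_i \sim \text{Poisson}(\text{input}_i)\)
    • rand(size): uniform distribution on \([0, 1)\)
    • rand_like()
    • randn(size): standard normal distribution \(\mathcal{N}(0, 1)\)
    • randn_like()
    • randperm(n): a random permutation of the integers 0 to n-1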
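    A quick check of the sampling functions above. It is seeded with torch.manual_seed so the run is repeatable. The sample sizes and the 0.3 and 2.0 parameters are arbitrary illustration values, not anything taken from the docs.

```python
import torch

torch.manual_seed(0)  # fix the RNG so runs are reproducible

u = torch.rand(1000)                           # samples from U[0, 1)
g = torch.randn(1000)                          # samples from N(0, 1)
p = torch.randperm(5)                          # a permutation of 0..4
b = torch.bernoulli(torch.full((1000,), 0.3))  # each entry is 1 with probability 0.3, else 0
n = torch.normal(mean=torch.zeros(3), std=torch.ones(3))  # one sample per (mean, std) pair
k = torch.poisson(torch.full((4,), 2.0))       # out_i ~ Poisson(2.0)

print(sorted(p.tolist()))                       # [0, 1, 2, 3, 4]
print(u.min().item() >= 0, u.max().item() < 1)  # True True
```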

    Tensor

    API

    The API mirrors ndarray, and the two convert to and from each other. The difference is that a Tensor can also run on a GPU.
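    A minimal sketch of the ndarray round trip and of moving a tensor to a GPU. It assumes NumPy is installed and falls back to the CPU when CUDA is unavailable.

```python
import numpy as np
import torch

a = np.arange(6, dtype=np.float32).reshape(2, 3)
t = torch.from_numpy(a)  # ndarray -> Tensor
back = t.numpy()         # Tensor -> ndarray (CPU tensors only)

# NumPy-like operations: shape, reductions along an axis, etc.
print(t.shape, t.sum(dim=0))  # torch.Size([2, 3]) tensor([3., 5., 7.])

# The difference: a Tensor can be moved to a GPU when one is available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
t_dev = t.to(device)
print(t_dev.device)
```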

    Common Tensor types for different data types:

    • torch.FloatTensor, torch.DoubleTensor
    • torch.ShortTensor, torch.IntTensor, torch.LongTensor

    The default floating-point type is FloatTensor (torch.float32).
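    One way to see which type you actually get is Tensor.type(), which returns the legacy type name as a string. This sketch assumes the default dtype has not been changed with torch.set_default_dtype(). The float default only applies to floating-point data: Python integers are inferred as int64 (LongTensor).

```python
import torch

print(torch.get_default_dtype())                   # torch.float32
print(torch.tensor([1.0, 2.0]).type())             # torch.FloatTensor
print(torch.tensor([1, 2]).type())                 # torch.LongTensor (ints infer int64)
print(torch.zeros(2).double().type())              # torch.DoubleTensor
print(torch.zeros(2, dtype=torch.int16).type())    # torch.ShortTensor
```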

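    Constructor

    torch.tensor(data, *, dtype=None, device=None, requires_grad=False, pin_memory=False)

    Parameters:

    • data: the initial data; a list, tuple, ndarray, scalar or tensor
    • dtype (torch.dtype): the tensor's data type. If omitted, it is inferred from data. Options include torch.float, torch.int and others
    • device (torch.device): where the tensor is stored. If omitted and data is a tensor, the tensor is placed on the same device as data; otherwise it goes on the current default device, which is the CPU unless changed. A common value is torch.device('cuda:0'). (Author's note: this parameter decides whether the tensor lives on the CPU or a GPU, for example when training a model on a GPU.)
    • requires_grad (bool): used by autograd; defaults to False. To change it later, call requires_grad_()
    • pin_memory (bool): only valid for CPU tensors. It allocates the tensor in pinned memory, also called page-locked memory, and defaults to False. Pinned memory speeds up copies from the CPU to the GPU. However, pinned pages cannot be swapped out, so pinning a lot of data takes physical RAM away from the rest of the system.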
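    A minimal sketch exercising the parameters above. The values are arbitrary. The GPU placement only happens when CUDA is available, and pin_memory=True is only attempted in that case because pinning needs CUDA support.

```python
import torch

x = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)  # explicit dtype: ints become floats
print(x.dtype, x.device)  # torch.float32 cpu

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
y = torch.tensor([1.0, 2.0, 3.0], device=device, requires_grad=True)

w = torch.tensor([0.5, 0.5])
w.requires_grad_()        # switch on gradient tracking after creation
loss = (w * w).sum()
loss.backward()
print(w.grad)             # tensor([1., 1.]) because d/dw sum(w^2) = 2w

# pin_memory applies to CPU tensors and needs CUDA support to be present
if torch.cuda.is_available():
    pinned = torch.tensor([1.0, 2.0], pin_memory=True)
    print(pinned.is_pinned())  # True
```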

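    To convert a NumPy array to a tensor without copying the data, use torch.as_tensor(). The copy is avoided only when the dtype and device already match.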
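    A small experiment to see which constructors share memory with the source array. When the ndarray is modified in place, the change shows up in tensors that share its buffer but not in copies. The float64 array is an arbitrary choice.

```python
import numpy as np
import torch

a = np.array([1.0, 2.0, 3.0])  # float64 ndarray

shared = torch.as_tensor(a)                           # same dtype, on CPU: no copy
from_np = torch.from_numpy(a)                         # always shares memory with a
copied = torch.tensor(a)                              # torch.tensor always copies
converted = torch.as_tensor(a, dtype=torch.float32)   # a dtype change forces a copy

a[0] = 100.0  # modify the original array in place
print(shared[0].item(), from_np[0].item())    # 100.0 100.0
print(copied[0].item(), converted[0].item())  # 1.0 1.0
```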

    Tensor.item()

    Only valid for tensors with exactly one element. Indexing a tensor with square brackets still returns a tensor; call item() to convert it into a Python number.

    >>> x = torch.tensor([1.0])
    >>> x.item()
    1.0
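    A slightly longer sketch of the indexing behaviour described above, using arbitrary tensor values. The exact exception that item() raises on a multi-element tensor has varied between PyTorch versions, so the sketch catches both RuntimeError and ValueError. To get Python numbers out of a larger tensor, use tolist().

```python
import torch

t = torch.tensor([[1.5, 2.0], [3.0, 4.0]])
elem = t[0, 0]                 # indexing still returns a (0-dim) tensor
print(type(elem), elem.shape)  # <class 'torch.Tensor'> torch.Size([])

value = elem.item()            # now a plain Python float
print(type(value), value)      # <class 'float'> 1.5

raised = False
try:
    t.item()                   # 4 elements: cannot become a single number
except (RuntimeError, ValueError):
    raised = True
print(raised)                  # True

print(t.tolist())              # [[1.5, 2.0], [3.0, 4.0]]
```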
    

    Data

    Dataset

    torch.utils.data.Dataset is an abstract class; subclasses must override the two methods __len__() and __getitem__().

    Map-style Dataset

    A map-style dataset is one that implements the __getitem__() and __len__() protocols, and represents a map from (possibly non-integral) indices/keys to data samples.

    For example, such a dataset, when accessed with dataset[idx], could read the idx-th image and its corresponding label from a folder on the disk.
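    A minimal map-style dataset, as a sketch. SquaresDataset is a hypothetical name, and it computes each sample instead of reading images from disk. It still implements the same two required methods that a file-backed dataset would.

```python
import torch
from torch.utils.data import Dataset

class SquaresDataset(Dataset):
    """Hypothetical map-style dataset: sample i is (i, i**2)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # number of samples; samplers and DataLoader rely on this
        return self.n

    def __getitem__(self, idx):
        # map an index to one sample (a real dataset might load the idx-th file here)
        x = torch.tensor(float(idx))
        return x, x ** 2

ds = SquaresDataset(5)
print(len(ds), ds[3])  # 5 (tensor(3.), tensor(9.))
```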

    Iterable-style datasets

    An iterable-style dataset is an instance of a subclass of IterableDataset that implements the __iter__() protocol, and represents an iterable over data samples. This type of datasets is particularly suitable for cases where random reads are expensive or even improbable, and where the batch size depends on the fetched data.

    For example, such a dataset, when called iter(dataset), could return a stream of data reading from a database, a remote server, or even logs generated in real time.
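    A minimal iterable-style sketch. CountStream is a hypothetical class that stands in for a real stream such as a database cursor. With num_workers=0, as here, the DataLoader batches samples in the order __iter__() yields them.

```python
import torch
from torch.utils.data import IterableDataset, DataLoader

class CountStream(IterableDataset):
    """Hypothetical iterable-style dataset streaming the integers start..end-1."""

    def __init__(self, start, end):
        self.start, self.end = start, end

    def __iter__(self):
        # a real stream could read from a database, a remote server or a live log
        for i in range(self.start, self.end):
            yield torch.tensor(i)

loader = DataLoader(CountStream(0, 7), batch_size=3)
batches = [b.tolist() for b in loader]
print(batches)  # [[0, 1, 2], [3, 4, 5], [6]]
```

    With num_workers > 0, every worker process gets its own copy of the dataset. A real stream would then need to split the work itself (for example with torch.utils.data.get_worker_info()) to avoid duplicate samples.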

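    Data loading

    For iterable-style datasets, the loading order is determined entirely by the user-defined __iter__().

    For map-style datasets, a torch.utils.data.Sampler determines the sequence of indices/keys used to load the data.

    Note: the torch docs do not explain this part clearly; it is best studied alongside a concrete project.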
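    A sketch of what a Sampler produces for a map-style dataset. A Sampler yields indices, and BatchSampler yields lists of indices. The DataLoader then uses these to call dataset[idx]. The 10-element data is arbitrary.

```python
import torch
from torch.utils.data import (BatchSampler, DataLoader, RandomSampler,
                              SequentialSampler, TensorDataset)

data = list(range(10))  # any sized source works for the samplers themselves

print(list(SequentialSampler(data)))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
shuffled = list(RandomSampler(data))  # the same indices in a random order
batches = list(BatchSampler(SequentialSampler(data), batch_size=3, drop_last=False))
print(batches)                        # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

# The DataLoader turns each list of indices into dataset[i] lookups plus collation
ds = TensorDataset(torch.arange(10))
loader = DataLoader(ds, batch_sampler=BatchSampler(SequentialSampler(ds), batch_size=3, drop_last=False))
first = next(iter(loader))
print(first[0].tolist())              # [0, 1, 2]
```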

    DataLoader

    DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
               batch_sampler=None, num_workers=0, collate_fn=None,
               pin_memory=False, drop_last=False, timeout=0,
               worker_init_fn=None, *, prefetch_factor=2,
               persistent_workers=False)
    
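    • dataset (Dataset) – the dataset to load data from
    • batch_size (int, optional) – how many samples per batch
    • shuffle (bool, optional) – whether to reshuffle the data at every epoch
    • collate_fn (callable, optional) – merges a list of samples into a mini-batch; used for batched loading from a map-style dataset. default_collate can be passed here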
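    A minimal DataLoader sketch that uses the parameters listed above on a TensorDataset of made-up data. collate_as_dict is a hypothetical collate_fn that illustrates the hook's contract: it receives a list of samples and returns one batch.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 10 made-up samples: 2 features each plus an integer label
features = torch.arange(20, dtype=torch.float32).reshape(10, 2)
labels = torch.arange(10)
dataset = TensorDataset(features, labels)

loader = DataLoader(dataset, batch_size=4, shuffle=True)
sizes, seen = [], []
for xb, yb in loader:        # the default collation stacks samples into batch tensors
    sizes.append(len(yb))
    seen.extend(yb.tolist())
print(sizes)                 # [4, 4, 2]: the last batch is smaller since drop_last=False

def collate_as_dict(samples):
    """Hypothetical collate_fn: merge a list of (x, y) samples into a dict of stacked tensors."""
    xs, ys = zip(*samples)
    return {"x": torch.stack(xs), "y": torch.stack(ys)}

batch = next(iter(DataLoader(dataset, batch_size=4, collate_fn=collate_as_dict)))
print(batch["x"].shape, batch["y"].tolist())  # torch.Size([4, 2]) [0, 1, 2, 3]
```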

    nn.Module

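    The base class for models: subclass it, define the layers in __init__() and the computation in forward(). Template: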

    import torch.nn as nn
    import torch.nn.functional as F
    
    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)
    
        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))
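    A usage sketch for the template above, with an assumed input batch of shape (8, 1, 28, 28), i.e. eight single-channel 28x28 images. Calling model(x) is the usual convention rather than model.forward(x), because __call__ also runs any registered hooks.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

model = Model()
x = torch.randn(8, 1, 28, 28)  # assumed batch: 8 single-channel 28x28 images
out = model(x)
print(out.shape)               # torch.Size([8, 20, 20, 20]): each 5x5 conv trims 4 pixels

n_params = sum(p.numel() for p in model.parameters())
print(n_params)                # 10540 = (20*1*25 + 20) + (20*20*25 + 20)
```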
    

    torch.optim

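    Implements many optimization algorithms. You first construct an optimizer object that holds the parameters to optimize; it then updates those parameters based on the computed gradients.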

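    import torch
    import torch.optim as optim

    model = torch.nn.Linear(10, 1)                  # placeholder model for illustration
    var1 = torch.randn(3, requires_grad=True)       # placeholder standalone tensors
    var2 = torch.randn(3, requires_grad=True)
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    optimizer = optim.Adam([var1, var2], lr=0.0001)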

    Supported algorithms: Adadelta, Adagrad, Adam, AdamW, SparseAdam, Adamax, ASGD, LBFGS, NAdam, RAdam, RMSprop, Rprop, SGD
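    A minimal training-loop sketch showing how the optimizer consumes the computed gradients. The toy regression target y = 2x + 1, the learning rate and the step count are illustrative assumptions. Each step follows the usual pattern: zero_grad(), forward pass, backward(), step().

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
# toy problem (an assumption for illustration): fit y = 2x + 1 with one linear layer
x = torch.linspace(-1, 1, 64).unsqueeze(1)
y = 2 * x + 1
model = nn.Linear(1, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
loss_fn = nn.MSELoss()

for step in range(200):
    optimizer.zero_grad()        # clear gradients left over from the previous step
    loss = loss_fn(model(x), y)  # forward pass
    loss.backward()              # autograd fills p.grad for every parameter
    optimizer.step()             # update the parameters using p.grad

print(loss.item(), model.weight.item(), model.bias.item())  # weight near 2, bias near 1
```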

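    Saving and loading models

    Here the entire model object is saved (via pickle), not just its parameters:

    • torch.save(model, './model.pth')
    • load_model = torch.load('./model.pth') (since PyTorch 2.6, torch.load defaults to weights_only=True and refuses to unpickle a whole model; pass weights_only=False, and only for files you trust)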
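    Saving the whole model pickles a reference to its class, so loading requires the same class definition to be importable. A commonly recommended alternative saves only the state_dict (parameters and buffers) and rebuilds the architecture in code before loading. The sketch below shows this with an arbitrary nn.Linear model and a temporary path.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # arbitrary model for illustration

path = os.path.join(tempfile.mkdtemp(), "model_state.pth")
torch.save(model.state_dict(), path)  # save only parameters and buffers

restored = nn.Linear(4, 2)            # rebuild the same architecture in code first
restored.load_state_dict(torch.load(path, map_location="cpu"))
restored.eval()                       # inference mode (matters for dropout/batchnorm layers)

x = torch.randn(3, 4)
print(torch.equal(model(x), restored(x)))  # True
```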
  • Original post: https://www.cnblogs.com/sherrlock/p/16288875.html