• torch笔记合集


    Torch笔记

    import torch
    import numpy as np
    import torch.nn as nn
    
    a_np = np.random.rand(10,100)
    

    numpy知识回顾

    a_np.dtype  # 数据类型
    a_np.ndim #维度个数
    a_np.shape # 形状   整数元祖
    a_np.dtype=np.int32  # 修改数据类型
    

    获取tensor基本信息

    tensor_a = torch.from_numpy(a_np)
    tensor_a.type() # 获取tensor类型
    tensor_a.size() # 获取维度特征 并返回数组
    tensor_a.size(0) #获取第一个维度信息
    tensor_a.dim() # 获取维度个数
    tensor_a.device # 返回设备类型
    

    tensor数据类型转换

    tensor_b = tensor_a.float()  
    tensor_b = tensor_a.to(torch.float) 
    

    设备类型转换

    tensor_b = tensor_b.cuda()
    tensor_b = tensor_b.cpu()
    tensor_b = tensor_a.to(torch.device(0)) 
    

    tensor 转换到 numpy数据

    tensor_b = tensor_b.numpy()
    # 注意*gpu上的tensor不能直接转为numpy
    ndarray = tensor_b.cpu().numpy()
    # numpy 转 torch.Tensor
    tensor = torch.from_numpy(ndarray) 
    

    tensor的维度变换

    • squeeze_
    tensor_a.t_().size() # 对tensor进行转置
    tensor_a.t_().size()
    tensor_a.unsqueeze_(2) #
    unsqu_tensor = tensor_a.squeeze_() #  将tensor中维度元素数为1 的全部压缩掉
    
    • reshape()
    tensor = torch.reshape(tensor_a, [50,20])
    # 也可以用tensor.reshape(50,20)
    tensor.size()
    
    • tanspose()
    tensor.transpose(1,2)  #交换tensor的1,2维度
    

    从只包含一个元素的张量中提取值

    tensor_a = tensor_a.cuda()
    tensor_a[1][1].item()
    tensor_a.device
    

    拼接张量

    • 例如当参数是3个10×5的张量,torch.cat的结果是30×5的张量,而torch.stack的结果是3×10×5的张量。
    temp1 = torch.from_numpy(np.random.rand(5,4))
    temp2 = torch.from_numpy(np.random.rand(5,4))
    
    temp1.size()
    temp2.size()
    temp3 = torch.stack([temp1,temp2],dim=0)
    temp4 = torch.cat([temp1,temp2],dim=0)
    
    temp3.size()
    temp4.size()
    

    矩阵运算

    • 矩阵乘法
    # Matrix multiplication: (m*n) * (n*p) -> (m*p).
    result = torch.mm(tensor1, tensor2)
    NOTE:This function does not broadcast. For broadcasting matrix products, see torch.matmul().
    result = torch.matmul(tensor1, tensor2)  # 支持broadcast
    
    • tensor中对应元素相乘
    tensor1*tensor2  # 支持broadcast
    
    • 矩阵求最大组索引
    res,index = torch.max(tensor,dim=-1)
    
    • 维度理解 :填入哪个维度则哪个维度消失 比如sum,softmax

    • 求平均值 ==avg_pooling
    avg = torch.mean(tensor,dim=-1)
    

    打印模型信息

    • 参数量
    • 模型结构 使用torch summary
    class myNet(nn.Module):
        def __init__(self,*other_para):
            super(myNet,self).__init__()
            self.embedding_layer = nn.Embedding(10,3)
    
    net = myNet()
    num_parameters = sum(torch.numel(parameter) for parameter in net.parameters())
    from torchsummary import summary
    summary(net,input_size=(2,2))
    

    模型初始化

    # Common practise for initialization.
    for layer in model.modules():
        if isinstance(layer, torch.nn.Conv2d):
            torch.nn.init.kaiming_normal_(layer.weight, mode='fan_out',
                                          nonlinearity='relu')
            if layer.bias is not None:
                torch.nn.init.constant_(layer.bias, val=0.0)
        elif isinstance(layer, torch.nn.BatchNorm2d):
            torch.nn.init.constant_(layer.weight, val=1.0)
            torch.nn.init.constant_(layer.bias, val=0.0)
        elif isinstance(layer, torch.nn.Linear):
            torch.nn.init.xavier_normal_(layer.weight)
            if layer.bias is not None:
                torch.nn.init.constant_(layer.bias, val=0.0)
    
    # Initialization with given tensor.
    layer.weight = torch.nn.Parameter(tensor)
    

    计算Softmax输出的准确率

    score = model(images)
    prediction = torch.argmax(score, dim=1)
    num_correct = torch.sum(prediction == labels).item()
    accuruacy = num_correct / labels.size(0)
    

    保存模型

    • torch.save(my_model.state_dict(), "params.pkl")

    加载模型

    • 先初始化model网络结构
    • model.load_state_dict(torch.load("params.pkl"))

    一些奇怪的注意事项

    • torch.nn.CrossEntropyLoss的输入不需要经过Softmax。torch.nn.CrossEntropyLoss等价于torch.nn.functional.log_softmax + torch.nn.NLLLoss。
    • torch的y不用转成onehot 它会自动帮你转 这点比较怪异,相比之下keras在多分类时 y必须是onehot的才能计算损失
    • model.train() :启用 BatchNormalization 和 Dropout
    • model.eval() :不启用 BatchNormalization 和 Dropout

    tricks

    • 用del及时删除不用的中间变量,节约GPU存储。
    • 使用inplace操作可节约GPU存储,如
  • 相关阅读:
    jdbc连接Mysql数据库
    测试ibatis3连接数据
    dbcp参数配置
    努力---是永远且持久的行为
    android---textview控件学习笔记之显示表情图片和文本(二)
    android---textview控件学习笔记之显示文本(一)
    程序员的要求
    android的adb命令中,pm,am的使用
    完成celery简单发送注册邮件
    培养代码逻辑
  • 原文地址:https://www.cnblogs.com/rise0111/p/11931507.html
Copyright © 2020-2023  润新知