• Pytorch基本变量类型FloatTensor与Variable


    pytorch中基本的变量类型当属FloatTensor(以下都用floattensor),而Variable(以下都用variable)是floattensor的封装,除了包含floattensor还包含有梯度信息

    pytorch中的dochi给出一些对于floattensor的基本的操作,比如四则运算以及平方等(链接),这些操作对于floattensor是十分的不友好,有时候需要写一个正则化的项需要写很长的一串,比如两个floattensor之间的相加需要用torch.add()来实现

    然而正确的打开方式并不是这样

    韩国一位大神写了一个pytorch的turorial,其中包含style transfer的一个代码实现

     1     for step in range(config.total_step):
     2         
     3         # Extract multiple(5) conv feature vectors
     4         target_features = vgg(target)   # 每一次输入到网络中的是同样一张图片,反传优化的目标是输入的target
     5         content_features = vgg(Variable(content))
     6         style_features = vgg(Variable(style))
     7 
     8         style_loss = 0
     9         content_loss = 0
    10         for f1, f2, f3 in zip(target_features, content_features, style_features):
    11             # Compute content loss (target and content image)
    12             content_loss += torch.mean((f1 - f2)**2)  # square 可以进行直接加-操作?可以,并且mean对所有的元素进行均值化造作
    13 
    14             # Reshape conv features
    15             _, c, h, w = f1.size()  # channel height width
    16             f1 = f1.view(c, h * w)  # reshape a vector
    17             f3 = f3.view(c, h * w)  # reshape a vector
    18 
    19             # Compute gram matrix  
    20             f1 = torch.mm(f1, f1.t())
    21             f3 = torch.mm(f3, f3.t())
    22 
    23             # Compute style loss (target and style image)
    24             style_loss += torch.mean((f1 - f3)**2) / (c * h * w)   # 总共元素的数目?

    其中f1与f2,f3的变量类型是Variable,作者对其直接用四则运算符进行加减,并且用python内置的**进行平方操作,然后

     1 # -*-coding: utf-8 -*-
     2 import torch
     3 from torch.autograd import Variable
     4 
     5 # dtype = torch.FloatTensor
     6 dtype = torch.cuda.FloatTensor  # Uncomment this to run on GPU
     7 
     8 # N is batch size; D_in is input dimension;
     9 # H is hidden dimension; D_out is output dimension.
    10 N, D_in, H, D_out = 64, 1000, 100, 10
    11 
    12 # Randomly initialize weights
    13 w1 = torch.randn(D_in, H).type(dtype)  # 两个权重矩阵
    14 w2 = torch.randn(D_in, H).type(dtype)
    15 # operate with +-*/ and **
    16 w3 = w1-2*w2
    17 w4 = w3**2
    18 w5 = w4/w1
    19 
    20 
    21 # operate the Variable with +-*/ and **
    22 w6 = Variable(torch.randn(N, D_in).type(dtype))
    23 w7 = Variable(torch.randn(N, D_in).type(dtype))
    24 w8 = w6 + w7
    25 w9 = w6*w7
    26 w10 = w9**2
    27 print(1)

    基本上调试的结果与预期相符

    所以,对于floattensor以及variable进行普通的+-×/以及**没毛病

  • 相关阅读:
    Soldier and Number Game素数筛
    HDU1501Zipper字符串的dfs
    HDU1285 确定比赛名次 拓扑排序模板题
    HDU1595 find the longest of the shortest dijkstra+记录路径
    HDU1556 Color the ball 前缀和/线段树/树状数组
    Function Run Fun递归+细节处理
    数学公式
    日常 java+雅思+训练题1
    HDU1423Greatest Common Increasing Subsequence
    HDU1595find the longest of the shortestdijkstra+记录路径
  • 原文地址:https://www.cnblogs.com/yongjieShi/p/8146028.html
Copyright © 2020-2023  润新知