• 读sru代码


    1、

    def read_corpus(path, eos="</s>"):
        """Read a whitespace-tokenised text file into one flat token list.

        Every line of the file is split on whitespace and the ``eos`` marker
        is appended after the tokens of that line, so a blank line contributes
        a lone ``eos``.

        Args:
            path: Path of the text file to read.
            eos: End-of-sentence marker appended after every line.

        Returns:
            A list of token strings covering the whole file, in file order.
        """
        tokens = []
        with open(path) as corpus:
            for sentence in corpus:
                tokens.extend(sentence.split())
                tokens.append(eos)
        return tokens
    

     来看一下这一段代码运行后产生的数据会是什么样子的

    # Inline version of read_corpus: collect every whitespace-separated token
    # of the file and close each line with the end-of-sentence marker.
    eos = "</s>"
    path = '/home/lai/下载/txt'
    data = []
    with open(path) as fin:
        for line in fin:
            data.extend(line.split())
            data.append(eos)
    print(data)
    

     这里的txt文件如下

    no it was n't black monday 
     but while the new york stock exchange did n't fall apart friday as the dow jones industrial average plunged N points most of it in the final hour it barely managed to stay this side of chaos 
     some circuit breakers installed after the october N crash failed their first test traders say unable to cool the selling panic in both stocks and futures 
    

     结果:

    ['no', 'it', 'was', "n't", 'black', 'monday', '</s>', 'but', 'while', 'the', 'new', 'york', 'stock', 'exchange', 'did', "n't", 'fall', 'apart', 'friday', 'as', 'the', 'dow', 'jones', 'industrial', 'average', 'plunged', 'N', 'points', 'most', 'of', 'it', 'in', 'the', 'final', 'hour', 'it', 'barely', 'managed', 'to', 'stay', 'this', 'side', 'of', 'chaos', '</s>', 'some', 'circuit', 'breakers', 'installed', 'after', 'the', 'october', 'N', 'crash', 'failed', 'their', 'first', 'test', 'traders', 'say', 'unable', 'to', 'cool', 'the', 'selling', 'panic', 'in', 'both', 'stocks', 'and', 'futures', '</s>']
    

     输出的是由单个单词组成的序列,每一行的末尾都加上了</s>作为结束标记

    2.

    class EmbeddingLayer(nn.Module):
        """Vocabulary plus embedding table built from a token sequence.

        Every distinct token in ``words`` receives an integer id in order of
        first appearance, and an ``nn.Embedding`` with one row per id is
        created so each word has its own vector.
        """

        def __init__(self, n_d, words, fix_emb=False):
            """Build the token-to-id mapping and the embedding table.

            Args:
                n_d: Size of each embedding vector (passed to ``nn.Embedding``
                    as the embedding dimension).
                words: Iterable of tokens; duplicates are allowed.
                fix_emb: Accepted by the signature but not used in this class.
            """
            super(EmbeddingLayer, self).__init__()
            vocab = {}
            for token in words:
                # setdefault only inserts unseen tokens, so ids follow the
                # order of first appearance: 0, 1, 2, ...
                vocab.setdefault(token, len(vocab))

            self.word2id = vocab
            self.n_V = len(vocab)  # vocabulary size
            self.n_d = n_d  # embedding dimension
            self.embedding = nn.Embedding(self.n_V, self.n_d)

        def forward(self, x):
            """Return the embedding vectors for the ids in ``x``."""
            vectors = self.embedding(x)
            return vectors

        def map_to_ids(self, text):
            """Convert an iterable of tokens to an int64 NumPy array of ids.

            Raises:
                KeyError: If a token was not present in ``words`` at
                    construction time.
            """
            lookup = self.word2id
            ids = [lookup[token] for token in text]
            return np.asarray(ids, dtype='int64')
    

    我构造了一个可以运行的简易程序进行理解

    import numpy as np
    # Toy labelled sentences: (token list, language tag).
    data = [("me gusta comer en la cafeteria".split(), "SPANISH"),
            ("Give it to me".split(), "ENGLISH"),
            ("No creo que sea una buena idea".split(), "SPANISH"),
            ("No it is not a good idea to get lost at sea".split(), "ENGLISH")]

    test_data = [("Yo creo que si".split(), "SPANISH"),
                 ("it is lost on me".split(), "ENGLISH")]

    # Map every distinct word to an integer id, in order of first appearance.
    word_to_ix = {}
    for sent, _ in data + test_data:
        for word in sent:
            if word not in word_to_ix:
                word_to_ix[word] = len(word_to_ix)
    print(word_to_ix)
    text = {'creo': 10, 'idea': 15, 'a': 18}
    # Convert a sentence into its numeric id sequence via word_to_ix.
    # Iterating the dict ``text`` yields its keys, i.e. the words to look up.
    print(np.asarray([word_to_ix[x] for x in text],
                     dtype='int64'))
    print(text)
    

    结果:

    {'Give': 6, 'lost': 21, 'No': 9, 'cafeteria': 5, 'comer': 2, 'en': 3, 'at': 22, 'not': 17, 'good': 19, 'to': 8, 'una': 13, 'Yo': 23, 'me': 0, 'a': 18, 'on': 25, 'creo': 10, 'get': 20, 'it': 7, 'idea': 15, 'buena': 14, 'is': 16, 'si': 24, 'que': 11, 'la': 4, 'gusta': 1, 'sea': 12}
    [15 10 18]
    {'idea': 15, 'creo': 10, 'a': 18}
    

     所以这一部分先将文字映射到数字,然后把一个句子sentence通过word_to_ix转换成数字化序列.

    关于读入数据的总结

    用代码中定义的类读入自己的数据

    import time
    import random
    import math
     
    import numpy as np
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.autograd import Variable
     
    def read_corpus(path, eos="</s>"):
        """Read a whitespace-tokenised text file into one flat token list.

        Every line of the file is split on whitespace and the ``eos`` marker
        is appended after the tokens of that line, so a blank line contributes
        a lone ``eos``.

        Args:
            path: Path of the text file to read.
            eos: End-of-sentence marker appended after every line.

        Returns:
            A list of token strings covering the whole file, in file order.
        """
        tokens = []
        with open(path) as corpus:
            for sentence in corpus:
                tokens.extend(sentence.split())
                tokens.append(eos)
        return tokens
     
    def create_batches(data_text, map_to_ids, batch_size):
        """Turn a token sequence into input/target tensors shifted by one step.

        The tokens are converted to ids with ``map_to_ids`` and the id array
        is printed.  The longest prefix whose length is a multiple of
        ``batch_size`` (computed from ``len - 1`` so one trailing id remains
        for the shifted targets) is reshaped to ``batch_size`` rows and
        transposed.  ``y`` is the same layout taken one position later.

        Args:
            data_text: Token sequence accepted by ``map_to_ids``.
            map_to_ids: Callable returning the id array; it must support
                slicing and ``reshape`` (a NumPy array in this file).
            batch_size: Number of columns in the returned tensors.

        Returns:
            Tuple ``(x, y)`` of contiguous tensors, each with ``batch_size``
            columns.
        """
        ids = map_to_ids(data_text)
        print(ids)
        total = len(ids)
        # Reserve one trailing id so the shifted targets never run off the end.
        usable = ((total - 1) // batch_size) * batch_size

        def as_columns(start):
            # Slice ``usable`` ids from ``start``, lay them out column-wise.
            window = ids[start:start + usable]
            columns = np.copy(window.reshape(batch_size, -1).T)
            return torch.from_numpy(columns).contiguous()

        return as_columns(0), as_columns(1)
     
     
    class EmbeddingLayer(nn.Module):
        """Vocabulary plus embedding table built from a token sequence.

        Every distinct token in ``words`` receives an integer id in order of
        first appearance, and an ``nn.Embedding`` with one row per id is
        created so each word has its own vector.
        """

        def __init__(self, n_d, words, fix_emb=False):
            """Build the token-to-id mapping and the embedding table.

            Args:
                n_d: Size of each embedding vector (passed to ``nn.Embedding``
                    as the embedding dimension).
                words: Iterable of tokens; duplicates are allowed.
                fix_emb: Accepted by the signature but not used in this class.
            """
            super(EmbeddingLayer, self).__init__()
            vocab = {}
            for token in words:
                # setdefault only inserts unseen tokens, so ids follow the
                # order of first appearance: 0, 1, 2, ...
                vocab.setdefault(token, len(vocab))

            self.word2id = vocab
            self.n_V = len(vocab)  # vocabulary size
            self.n_d = n_d  # embedding dimension
            self.embedding = nn.Embedding(self.n_V, self.n_d)

        def forward(self, x):
            """Return the embedding vectors for the ids in ``x``."""
            vectors = self.embedding(x)
            return vectors

        def map_to_ids(self, text):
            """Convert an iterable of tokens to an int64 NumPy array of ids.

            Raises:
                KeyError: If a token was not present in ``words`` at
                    construction time.
            """
            lookup = self.word2id
            ids = [lookup[token] for token in text]
            return np.asarray(ids, dtype='int64')
    # Load the corpus as one flat token list and show it.
    corpus_path = '/home/lai/下载/train.txt'
    train = read_corpus(corpus_path)
    print(train)

    # Build the vocabulary and a 10-dimensional embedding table from the tokens.
    model = EmbeddingLayer(10, train)
    print(model)

    # Keep the bound method so it can be handed to create_batches.
    map_to_ids = model.map_to_ids
    print(map_to_ids)

    # ``train`` is rebound from the token list to the (x, y) tensor pair.
    train = create_batches(train, map_to_ids, batch_size=45)
    print(train)
    print(model.embedding.weight)
    

     结果

    ['no', 'it', 'was', "n't", 'black', 'monday', '</s>', 'but', 'while', 'the', 'new', 'york', 'stock', 'exchange', 'did', "n't", 'fall', 'apart', 'friday', 'as', 'the', 'dow', 'jones', 'industrial', 'average', 'plunged', 'N', 'points', 'most', 'of', 'it', 'in', 'the', 'final', 'hour', 'it', 'barely', 'managed', 'to', 'stay', 'this', 'side', 'of', 'chaos', '</s>', 'some', 'circuit', 'breakers', 'installed', 'after', 'the', 'october', 'N', 'crash', 'failed', 'their', 'first', 'test', 'traders', 'say', 'unable', 'to', 'cool', 'the', 'selling', 'panic', 'in', 'both', 'stocks', 'and', 'futures', '</s>']
    EmbeddingLayer (
      (embedding): Embedding(59, 10)
    )
    <bound method EmbeddingLayer.map_to_ids of EmbeddingLayer (
      (embedding): Embedding(59, 10)
    )>
    [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14  3 15 16 17 18  9 19 20 21 22
     23 24 25 26 27  1 28  9 29 30  1 31 32 33 34 35 36 27 37  6 38 39 40 41 42
      9 43 24 44 45 46 47 48 49 50 51 33 52  9 53 54 28 55 56 57 58  6]
    (
    
    Columns 0 to 12 
        0     1     2     3     4     5     6     7     8     9    10    11    12
    
    Columns 13 to 25 
       13    14     3    15    16    17    18     9    19    20    21    22    23
    
    Columns 26 to 38 
       24    25    26    27     1    28     9    29    30     1    31    32    33
    
    Columns 39 to 44 
       34    35    36    27    37     6
    [torch.LongTensor of size 1x45]
    , 
    
    Columns 0 to 12 
        1     2     3     4     5     6     7     8     9    10    11    12    13
    
    Columns 13 to 25 
       14     3    15    16    17    18     9    19    20    21    22    23    24
    
    Columns 26 to 38 
       25    26    27     1    28     9    29    30     1    31    32    33    34
    
    Columns 39 to 44 
       35    36    27    37     6    38
    [torch.LongTensor of size 1x45]
    )
    Parameter containing:
     0.4376 -1.1509 -0.1407 -0.6956 -0.7292 -0.1944  0.8925  0.0688 -0.0560  2.5919
    -0.7855 -0.0448 -0.8069 -1.4774  0.2366  0.3967 -0.0706 -0.4602  1.0099 -0.0734
    -1.7748 -0.5265  0.4334 -0.7525 -0.0537  0.3966 -1.1800  0.2774 -2.2269 -0.4814
    -0.9325  1.7541  0.6094 -0.1564  0.8379 -0.4577 -1.3616 -2.1115 -0.7025 -0.6662
     1.0896 -0.1558 -1.1896 -0.0955 -2.7685  0.9485  1.1311 -1.1454 -0.4689  1.0410
     1.2227  1.8617  0.9243 -0.3036  0.2639 -0.6933 -0.4147 -0.4482  2.7447  0.0573
     1.0230  0.0484 -1.0139 -0.4291  0.6560  0.6911 -1.2519  0.9809  0.5843  0.2033
    -0.1128 -0.2149  1.2092  1.5636 -0.6737  1.0226  1.0155 -0.6230 -2.1714 -0.0226
     0.1947  1.0509  0.8694  1.5002 -0.3447 -0.2618  1.3267  0.0795  0.5041 -0.9763
     1.0146  0.9310 -1.2894  1.3288 -0.4146  0.1909 -0.3760  1.6011  0.7943  0.6290
    -0.2122 -1.4665  1.4775  0.5200  1.2882 -0.4101  0.4479  0.4447 -0.9597  1.7938
     0.8239  0.5278 -0.0036  0.8840  0.1069  0.2539 -0.7887  0.1271  0.8512  0.3766
    -0.5573  0.6985  1.0623 -1.3442  1.0792  0.4055  0.3625  1.7664 -0.3776  0.0266
    -0.2160  0.6872  1.6154 -0.5749  2.6781  1.1730 -0.9687 -1.2116 -0.9464  0.5248
     0.0916  0.3761 -1.0593 -0.6794  1.6780 -0.2040  0.8541 -0.0384  1.5180  0.6114
    -0.0321  0.5364  0.3896 -0.4864 -1.0080 -1.0698  0.1935  0.3896 -0.5745 -0.0273
     1.6301 -0.2652 -0.5325 -0.9380  0.3457 -2.0038 -0.0775 -0.7555 -0.8524 -0.9321
     0.0364 -0.4582 -0.3213 -0.9254 -1.0728 -0.1355  0.0993 -0.3186  2.3914 -1.5035
     0.0652  0.7371  0.9628  1.1530 -0.4044 -0.7131 -0.8299  1.6627 -0.8451 -1.0463
    -0.3744  0.6010 -2.4774  1.6569 -0.5589 -0.6512 -1.3728 -1.7573  1.1402  1.6838
     0.2883 -1.3225  1.2454  0.4222 -0.5544 -1.5851  1.7119  1.3759  1.2300 -0.0676
     0.6371  1.4258 -0.0222  1.2869  0.8767 -0.2959 -0.5973 -2.6143 -0.4366  0.9691
     0.3215  0.6463  0.4688  0.4125  0.1800  0.0441  0.0375  0.4195  1.5675  0.7011
     0.5407  1.4961 -1.5759 -1.7088 -0.5991  1.2169  0.9620 -1.7427 -0.0108 -0.3502
    -0.0906  0.1109 -0.4118  1.0876  0.8098 -0.8063 -0.2878  0.8896 -0.6304  0.0683
     0.6119  0.4786  0.6667  0.5702 -1.0531  0.4991  0.0538  1.1451 -0.7958 -0.0557
     1.3344  1.7192 -1.9320  2.1928 -0.1014  0.6543 -0.1026 -0.6506 -0.2592  0.0537
    -1.0320  1.9222 -0.6615  0.8046 -0.7667 -0.6775 -0.4904  0.6054  0.2837 -1.2075
     0.6694 -0.7456 -0.9112  0.0961  0.3517 -0.6020 -0.9233  0.8343  0.0364 -0.5247
    -1.4859 -0.8458  0.1642  0.2666 -2.9028  0.5945  0.0080  0.2036  1.9158  0.4553
     1.9948 -0.1500 -1.9221 -0.2734  0.7872  0.1108 -0.1790 -0.0549  0.8124  0.1027
    -0.8605  2.0634 -1.1081  0.3951  0.6214  0.1754  0.4764  0.9175 -0.3207 -0.3007
     0.3095  1.4426 -0.6971 -1.1740  0.7263  0.0415 -0.4804  0.2983  0.9156  0.6196
    -0.0862 -0.6351 -2.7732  1.2055  0.8422 -1.9189  1.4048 -0.8839  0.0811 -1.1528
    -0.5930  1.2625  0.5828 -0.8534  0.5789 -1.8812  1.2968  1.1347 -1.3243  0.5715
    -0.3339  0.5853  0.1010  1.2207  1.0524 -1.5834 -2.1429  0.7626  1.6698  0.7554
    -1.0038  1.6710 -0.6395 -0.3707  0.3491  0.0697  0.2043  0.2882  1.3192 -2.2766
     1.1236 -0.3770 -0.4992  0.3957 -1.0027  0.7676  1.3439  1.1695 -0.0786  0.0372
     0.1163 -0.4600 -1.2990 -0.6624  0.6378  0.4357 -0.2231  0.8826  0.7718  0.6312
    -0.9322  0.7925  1.0265 -0.9309  0.3586 -0.2663  0.7529 -0.8931  0.3230  1.0597
     0.0599  0.3668  0.2117 -0.3740 -1.2131 -0.7596 -0.1819  0.4357  3.0936  0.7486
    -0.7667 -0.3219 -0.3511 -0.6781  0.8756  1.2539  0.7989  0.6129  0.3743  0.6551
     0.8160 -0.3391 -0.4200  0.0984  0.0863 -1.1544  0.6204 -0.6724  0.2659  0.5388
     0.4748  0.5738 -0.8648  0.3691 -0.3480 -0.1510  0.8260  0.6924  0.0053 -0.6213
     0.2044  0.7698  0.7638  0.3532  0.7197  0.9445 -1.0761  0.0882  0.5684  0.4562
    -1.0330 -1.0507 -1.1679  0.0608  1.3512  0.2507  0.1740 -0.1574 -0.0552  0.6377
     1.3845  1.3252  2.5621 -0.5241  0.4334 -0.5092  0.1271 -1.3832  0.7112  0.1932
    -0.1659  0.2740 -0.6393 -0.2937 -0.2887 -0.7221 -1.1947 -1.0431  1.1029 -1.1171
    -0.2033 -0.5364 -0.4530 -2.4491 -1.2100 -1.5732  0.4191 -2.8109  0.3529 -0.7417
     0.1667 -0.0072  0.8795 -0.1538  0.5413  1.1036 -0.5249 -0.8432  0.0563 -0.2998
    -0.4226  0.6448 -0.4215  0.4342 -0.6593 -0.2078  1.4768  1.1829  0.8084 -2.0024
     2.1950  0.8189  0.4104  0.4159 -1.1775 -2.3510 -0.5108 -2.5914 -0.5550  0.7188
    -0.2978  0.1422 -0.0790 -1.6337 -0.4799 -0.9623 -0.9411  0.8321 -1.6386 -0.7785
    -0.3109  0.5793  0.5437  0.3324 -0.9796  1.4794  0.0364  0.6472  0.7203  1.5878
     0.0685  1.5637 -0.4545 -2.2541  0.5353  0.1305  1.3973 -1.2065 -0.5373  1.3352
     0.0670 -0.6708 -0.4448  0.1797 -0.6935  1.4199  0.2560  0.3542 -1.0556 -1.1745
    -0.3048  1.7749 -0.5777 -0.7029  0.9634 -0.9982  1.1929  1.5102  0.7618 -0.3569
     0.1294 -1.6825 -0.8473 -0.7886  0.3286 -0.2387 -0.4245 -0.3130  0.2273 -1.0860
    -0.7929 -1.0838  0.1994 -0.4874  0.6568  0.1065  1.8086  0.2142 -1.1657 -0.2313
    [torch.FloatTensor of size 59x10]
    

     我把这个过程的中间结果全都打印出来,便于理解,对于model.embedding.weight,这个embedding层的weight应该是指每个单词所对应的向量

    3.

    def init_weights(self):
        """Re-initialise every parameter of the module in place.

        Parameters with more than one dimension (matrices) are filled with
        uniform random values between ``-val_range`` and ``val_range``, where
        ``val_range = sqrt(3 / self.n_d)``; all remaining parameters are set
        to zero.
        """
        val_range = (3.0 / self.n_d) ** 0.5
        for p in self.parameters():
            if p.dim() > 1:  # matrix
                p.data.uniform_(-val_range, val_range)
            else:
                p.data.zero_()

     p.data.uniform_(-val_range, val_range)和p.data.zero_()

    我自己构造了一个模型用以探究其功能

    import time
    import random
    import math
     
    import numpy as np
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.autograd import Variable
     
    def read_corpus(path, eos="</s>"):
        """Read a whitespace-tokenised text file into one flat token list.

        Every line of the file is split on whitespace and the ``eos`` marker
        is appended after the tokens of that line, so a blank line contributes
        a lone ``eos``.

        Args:
            path: Path of the text file to read.
            eos: End-of-sentence marker appended after every line.

        Returns:
            A list of token strings covering the whole file, in file order.
        """
        tokens = []
        with open(path) as corpus:
            for sentence in corpus:
                tokens.extend(sentence.split())
                tokens.append(eos)
        return tokens
     
    def create_batches(data_text, map_to_ids, batch_size):
        """Turn a token sequence into input/target tensors shifted by one step.

        The tokens are converted to ids with ``map_to_ids`` and the id array
        is printed.  The longest prefix whose length is a multiple of
        ``batch_size`` (computed from ``len - 1`` so one trailing id remains
        for the shifted targets) is reshaped to ``batch_size`` rows and
        transposed.  ``y`` is the same layout taken one position later.

        Args:
            data_text: Token sequence accepted by ``map_to_ids``.
            map_to_ids: Callable returning the id array; it must support
                slicing and ``reshape`` (a NumPy array in this file).
            batch_size: Number of columns in the returned tensors.

        Returns:
            Tuple ``(x, y)`` of contiguous tensors, each with ``batch_size``
            columns.
        """
        ids = map_to_ids(data_text)
        print(ids)
        total = len(ids)
        # Reserve one trailing id so the shifted targets never run off the end.
        usable = ((total - 1) // batch_size) * batch_size

        def as_columns(start):
            # Slice ``usable`` ids from ``start``, lay them out column-wise.
            window = ids[start:start + usable]
            columns = np.copy(window.reshape(batch_size, -1).T)
            return torch.from_numpy(columns).contiguous()

        return as_columns(0), as_columns(1)
     
     
    class EmbeddingLayer(nn.Module):
        """Vocabulary plus embedding table built from a token sequence.

        Every distinct token in ``words`` receives an integer id in order of
        first appearance, and an ``nn.Embedding`` with one row per id is
        created so each word has its own vector.
        """

        def __init__(self, n_d, words, fix_emb=False):
            """Build the token-to-id mapping and the embedding table.

            Args:
                n_d: Size of each embedding vector (passed to ``nn.Embedding``
                    as the embedding dimension).
                words: Iterable of tokens; duplicates are allowed.
                fix_emb: Accepted by the signature but not used in this class.
            """
            super(EmbeddingLayer, self).__init__()
            vocab = {}
            for token in words:
                # setdefault only inserts unseen tokens, so ids follow the
                # order of first appearance: 0, 1, 2, ...
                vocab.setdefault(token, len(vocab))

            self.word2id = vocab
            self.n_V = len(vocab)  # vocabulary size
            self.n_d = n_d  # embedding dimension
            self.embedding = nn.Embedding(self.n_V, self.n_d)

        def forward(self, x):
            """Return the embedding vectors for the ids in ``x``."""
            vectors = self.embedding(x)
            return vectors

        def map_to_ids(self, text):
            """Convert an iterable of tokens to an int64 NumPy array of ids.

            Raises:
                KeyError: If a token was not present in ``words`` at
                    construction time.
            """
            lookup = self.word2id
            ids = [lookup[token] for token in text]
            return np.asarray(ids, dtype='int64')
    # Load the corpus and build a 10-dimensional embedding layer from it.
    train = read_corpus('/home/lai/下载/train.txt')
    print(train)
    model = EmbeddingLayer(10, train)
    # Overwrite each parameter with uniform random values and print both the
    # value returned by uniform_ and the parameter's data afterwards.
    for weights in model.parameters():
        refilled = weights.data.uniform_(0, 2)
        print(refilled)
        print(weights.data)
    

     结果:

    ['no', 'it', 'was', "n't", 'black', 'monday', '</s>', 'but', 'while', 'the', 'new', 'york', 'stock', 'exchange', 'did', "n't", 'fall', 'apart', 'friday', 'as', 'the', 'dow', 'jones', 'industrial', 'average', 'plunged', 'N', 'points', 'most', 'of', 'it', 'in', 'the', 'final', 'hour', 'it', 'barely', 'managed', 'to', 'stay', 'this', 'side', 'of', 'chaos', '</s>', 'some', 'circuit', 'breakers', 'installed', 'after', 'the', 'october', 'N', 'crash', 'failed', 'their', 'first', 'test', 'traders', 'say', 'unable', 'to', 'cool', 'the', 'selling', 'panic', 'in', 'both', 'stocks', 'and', 'futures', '</s>']
    
     1.4317  0.6596  0.0516  1.0376  0.1926  1.2600  0.0494  0.8796  1.9962  1.2159
     0.2419  0.6704  0.1465  1.6639  1.5062  1.6871  0.7300  1.6097  0.6998  1.1892
     0.8882  0.7436  0.7304  0.6540  1.0289  0.7935  1.9055  1.5515  1.2066  1.7531
     1.1168  1.8315  0.7545  1.8267  0.9284  0.4486  0.5175  0.0532  0.8085  1.3437
     0.2860  0.2907  0.8077  1.9553  1.2979  1.1078  0.0623  1.8027  1.8158  0.0852
     1.0238  0.3384  0.5703  1.5060  1.0183  0.2247  0.2230  0.7064  0.3984  1.6884
     1.1680  1.5321  0.9316  1.9031  0.5216  0.8028  0.8465  0.5166  1.5459  0.2865
     0.6001  1.1145  1.6196  1.7692  1.7195  1.3123  0.4399  0.4006  1.2029  1.6420
     1.9466  1.9689  0.8811  0.2398  1.3328  0.5307  1.6048  0.9328  1.6946  0.5598
     1.9595  0.3396  1.4121  0.1757  0.3677  0.5584  1.9388  1.2118  1.3966  1.4618
     1.2004  0.8745  0.4966  1.5487  0.7805  1.0708  1.8857  0.1973  1.1339  1.0490
     0.4731  0.2265  1.0293  0.7514  1.3949  1.5742  0.0032  1.0001  1.6449  1.4519
     0.2014  0.0456  1.2669  1.2988  0.9432  1.0757  0.6428  1.3084  0.7477  0.3753
     0.1086  0.1842  1.3811  1.4472  0.6998  0.0028  1.8839  1.0238  1.6243  1.3262
     0.6383  1.4817  0.2363  1.7802  1.2998  1.8367  1.9967  0.5028  0.0819  1.4886
     0.2979  0.3566  0.5144  0.6787  0.8583  0.9256  0.8171  0.0482  0.6638  1.3788
     0.4180  1.5806  1.0489  0.6587  1.6041  1.0644  1.9635  1.4030  1.5242  1.9292
     1.7177  1.0168  1.4879  1.5941  0.6318  0.4966  1.9573  1.0276  1.8955  0.9595
     1.3229  0.5519  0.0796  1.0840  0.2204  0.7510  0.6440  0.7307  1.0064  1.0647
     0.5325  1.1621  1.0669  1.2276  0.2488  1.6607  1.6797  1.7445  0.7051  0.0290
     1.9457  0.8071  1.9667  1.5591  1.6706  1.8955  0.2541  1.2218  0.5843  1.8493
     0.8763  0.2127  0.5883  0.9636  1.9839  0.5030  0.8972  0.3293  1.1231  0.8687
     1.3803  0.9248  1.3445  0.1882  1.3226  1.9621  1.0377  1.7566  1.6686  1.6855
     1.9552  0.1764  0.6670  1.5401  0.4913  0.8954  0.3951  0.8991  1.5485  0.6603
     0.5025  1.1702  1.8270  0.9304  0.4637  1.4306  0.5506  0.3712  0.0122  0.4379
     0.2657  0.0599  1.8354  0.2358  1.7581  0.3380  0.9558  1.7275  0.5202  1.3801
     0.7791  1.4060  0.6530  1.8742  0.5895  0.7742  1.7748  1.7141  1.2038  0.2918
     1.0312  1.9371  0.8345  0.4569  0.0447  0.2415  1.3479  0.9809  0.0566  1.0656
     0.3313  0.4801  0.3357  1.4143  0.6487  0.7692  1.0398  1.1538  0.8307  0.8231
     1.4774  0.1299  1.1836  0.2659  1.4413  0.4059  0.2428  1.0973  0.5491  0.2169
     1.8733  0.7073  0.6730  1.7413  1.1705  1.7082  1.0175  1.2589  1.9080  0.7648
     1.0761  1.1880  1.5441  1.9458  0.5513  1.5324  1.3756  0.3201  1.6600  0.7143
     1.8071  1.2422  1.5758  1.5677  1.5796  1.0328  0.3856  0.3648  0.5017  1.2543
     1.8749  1.9269  0.2120  0.3971  0.4451  0.7651  0.6793  0.1512  1.7845  0.1911
     1.2950  0.9356  1.0757  0.7603  0.6917  0.2891  1.3327  1.1102  0.3153  1.7074
     0.9031  1.8973  1.6392  0.3516  0.4412  1.4444  1.4032  0.1110  1.1379  0.2283
     0.4678  1.3409  0.6576  0.5351  1.2108  1.7777  0.5716  1.9060  1.4147  1.4487
     0.9546  0.9840  0.3020  1.7696  0.9677  1.1206  1.5639  0.0437  0.1485  0.1437
     1.0374  0.8910  1.7921  1.1207  0.4798  0.5863  0.0112  0.7735  0.8233  0.8936
     1.1980  1.6834  0.5779  0.7173  1.5803  1.6196  0.1642  1.6706  1.9906  1.4089
     0.2140  0.6833  1.6710  0.4645  0.0886  1.6945  0.8467  1.3290  1.7448  0.5405
     1.2914  1.5487  0.8509  1.8434  1.3398  0.3215  0.5732  1.5421  1.5103  0.2807
     1.4965  0.5448  1.0851  0.6836  1.4491  0.4040  1.8560  1.2288  1.4055  0.7298
     0.6319  0.9501  0.5320  1.2168  0.0031  1.8810  1.5128  0.4442  1.3887  1.5603
     0.5936  1.9980  1.4988  0.5884  1.9388  1.8275  0.1833  1.3767  1.2934  0.6319
     0.2711  0.0854  0.7103  0.8877  1.9997  0.2341  0.7163  1.8445  1.4777  0.0532
     1.1966  1.1512  1.8602  0.0552  1.7778  0.4180  1.0675  1.0646  1.6946  1.9979
     1.4076  0.1683  0.6894  1.0616  1.8683  0.3648  0.9496  0.4799  1.5983  0.8257
     1.5951  0.7438  0.4807  1.7440  1.1139  1.5855  0.3561  0.5960  0.6389  1.7573
     1.3262  1.5965  0.1100  1.0414  0.1697  1.8125  0.8135  0.1712  0.8863  0.5336
     0.4490  0.1233  0.0136  1.3416  0.2668  0.2091  0.8900  0.3823  1.3197  1.4936
     1.3607  0.6022  0.9031  0.7420  0.5538  1.5407  1.1918  0.5104  1.7564  0.1658
     0.4650  0.4523  1.3443  1.5691  1.0239  0.5898  0.8882  0.1892  1.0721  1.6908
     1.0479  1.9074  0.3732  1.8763  1.5337  0.2918  1.9343  1.6055  0.0709  0.9326
     0.6884  1.6136  1.1970  1.0819  0.3358  0.0234  0.4381  1.2239  1.1829  1.1254
     1.4076  0.4704  0.1724  0.5579  0.1318  0.5537  0.2435  0.8490  0.7200  1.5814
     0.2753  0.4727  0.5446  1.7038  0.8742  1.2662  1.3187  0.5939  1.2068  0.3514
     0.6184  1.6217  1.0503  1.0958  1.9824  0.6737  0.3009  0.7889  1.8378  1.7559
     0.6418  1.8355  0.7340  0.7232  0.6433  0.0288  1.3672  0.6466  0.3574  1.0760
    [torch.FloatTensor of size 59x10]
    
    
     1.4317  0.6596  0.0516  1.0376  0.1926  1.2600  0.0494  0.8796  1.9962  1.2159
     0.2419  0.6704  0.1465  1.6639  1.5062  1.6871  0.7300  1.6097  0.6998  1.1892
     0.8882  0.7436  0.7304  0.6540  1.0289  0.7935  1.9055  1.5515  1.2066  1.7531
     1.1168  1.8315  0.7545  1.8267  0.9284  0.4486  0.5175  0.0532  0.8085  1.3437
     0.2860  0.2907  0.8077  1.9553  1.2979  1.1078  0.0623  1.8027  1.8158  0.0852
     1.0238  0.3384  0.5703  1.5060  1.0183  0.2247  0.2230  0.7064  0.3984  1.6884
     1.1680  1.5321  0.9316  1.9031  0.5216  0.8028  0.8465  0.5166  1.5459  0.2865
     0.6001  1.1145  1.6196  1.7692  1.7195  1.3123  0.4399  0.4006  1.2029  1.6420
     1.9466  1.9689  0.8811  0.2398  1.3328  0.5307  1.6048  0.9328  1.6946  0.5598
     1.9595  0.3396  1.4121  0.1757  0.3677  0.5584  1.9388  1.2118  1.3966  1.4618
     1.2004  0.8745  0.4966  1.5487  0.7805  1.0708  1.8857  0.1973  1.1339  1.0490
     0.4731  0.2265  1.0293  0.7514  1.3949  1.5742  0.0032  1.0001  1.6449  1.4519
     0.2014  0.0456  1.2669  1.2988  0.9432  1.0757  0.6428  1.3084  0.7477  0.3753
     0.1086  0.1842  1.3811  1.4472  0.6998  0.0028  1.8839  1.0238  1.6243  1.3262
     0.6383  1.4817  0.2363  1.7802  1.2998  1.8367  1.9967  0.5028  0.0819  1.4886
     0.2979  0.3566  0.5144  0.6787  0.8583  0.9256  0.8171  0.0482  0.6638  1.3788
     0.4180  1.5806  1.0489  0.6587  1.6041  1.0644  1.9635  1.4030  1.5242  1.9292
     1.7177  1.0168  1.4879  1.5941  0.6318  0.4966  1.9573  1.0276  1.8955  0.9595
     1.3229  0.5519  0.0796  1.0840  0.2204  0.7510  0.6440  0.7307  1.0064  1.0647
     0.5325  1.1621  1.0669  1.2276  0.2488  1.6607  1.6797  1.7445  0.7051  0.0290
     1.9457  0.8071  1.9667  1.5591  1.6706  1.8955  0.2541  1.2218  0.5843  1.8493
     0.8763  0.2127  0.5883  0.9636  1.9839  0.5030  0.8972  0.3293  1.1231  0.8687
     1.3803  0.9248  1.3445  0.1882  1.3226  1.9621  1.0377  1.7566  1.6686  1.6855
     1.9552  0.1764  0.6670  1.5401  0.4913  0.8954  0.3951  0.8991  1.5485  0.6603
     0.5025  1.1702  1.8270  0.9304  0.4637  1.4306  0.5506  0.3712  0.0122  0.4379
     0.2657  0.0599  1.8354  0.2358  1.7581  0.3380  0.9558  1.7275  0.5202  1.3801
     0.7791  1.4060  0.6530  1.8742  0.5895  0.7742  1.7748  1.7141  1.2038  0.2918
     1.0312  1.9371  0.8345  0.4569  0.0447  0.2415  1.3479  0.9809  0.0566  1.0656
     0.3313  0.4801  0.3357  1.4143  0.6487  0.7692  1.0398  1.1538  0.8307  0.8231
     1.4774  0.1299  1.1836  0.2659  1.4413  0.4059  0.2428  1.0973  0.5491  0.2169
     1.8733  0.7073  0.6730  1.7413  1.1705  1.7082  1.0175  1.2589  1.9080  0.7648
     1.0761  1.1880  1.5441  1.9458  0.5513  1.5324  1.3756  0.3201  1.6600  0.7143
     1.8071  1.2422  1.5758  1.5677  1.5796  1.0328  0.3856  0.3648  0.5017  1.2543
     1.8749  1.9269  0.2120  0.3971  0.4451  0.7651  0.6793  0.1512  1.7845  0.1911
     1.2950  0.9356  1.0757  0.7603  0.6917  0.2891  1.3327  1.1102  0.3153  1.7074
     0.9031  1.8973  1.6392  0.3516  0.4412  1.4444  1.4032  0.1110  1.1379  0.2283
     0.4678  1.3409  0.6576  0.5351  1.2108  1.7777  0.5716  1.9060  1.4147  1.4487
     0.9546  0.9840  0.3020  1.7696  0.9677  1.1206  1.5639  0.0437  0.1485  0.1437
     1.0374  0.8910  1.7921  1.1207  0.4798  0.5863  0.0112  0.7735  0.8233  0.8936
     1.1980  1.6834  0.5779  0.7173  1.5803  1.6196  0.1642  1.6706  1.9906  1.4089
     0.2140  0.6833  1.6710  0.4645  0.0886  1.6945  0.8467  1.3290  1.7448  0.5405
     1.2914  1.5487  0.8509  1.8434  1.3398  0.3215  0.5732  1.5421  1.5103  0.2807
     1.4965  0.5448  1.0851  0.6836  1.4491  0.4040  1.8560  1.2288  1.4055  0.7298
     0.6319  0.9501  0.5320  1.2168  0.0031  1.8810  1.5128  0.4442  1.3887  1.5603
     0.5936  1.9980  1.4988  0.5884  1.9388  1.8275  0.1833  1.3767  1.2934  0.6319
     0.2711  0.0854  0.7103  0.8877  1.9997  0.2341  0.7163  1.8445  1.4777  0.0532
     1.1966  1.1512  1.8602  0.0552  1.7778  0.4180  1.0675  1.0646  1.6946  1.9979
     1.4076  0.1683  0.6894  1.0616  1.8683  0.3648  0.9496  0.4799  1.5983  0.8257
     1.5951  0.7438  0.4807  1.7440  1.1139  1.5855  0.3561  0.5960  0.6389  1.7573
     1.3262  1.5965  0.1100  1.0414  0.1697  1.8125  0.8135  0.1712  0.8863  0.5336
     0.4490  0.1233  0.0136  1.3416  0.2668  0.2091  0.8900  0.3823  1.3197  1.4936
     1.3607  0.6022  0.9031  0.7420  0.5538  1.5407  1.1918  0.5104  1.7564  0.1658
     0.4650  0.4523  1.3443  1.5691  1.0239  0.5898  0.8882  0.1892  1.0721  1.6908
     1.0479  1.9074  0.3732  1.8763  1.5337  0.2918  1.9343  1.6055  0.0709  0.9326
     0.6884  1.6136  1.1970  1.0819  0.3358  0.0234  0.4381  1.2239  1.1829  1.1254
     1.4076  0.4704  0.1724  0.5579  0.1318  0.5537  0.2435  0.8490  0.7200  1.5814
     0.2753  0.4727  0.5446  1.7038  0.8742  1.2662  1.3187  0.5939  1.2068  0.3514
     0.6184  1.6217  1.0503  1.0958  1.9824  0.6737  0.3009  0.7889  1.8378  1.7559
     0.6418  1.8355  0.7340  0.7232  0.6433  0.0288  1.3672  0.6466  0.3574  1.0760
    [torch.FloatTensor of size 59x10]
    

     param.data.uniform_(0,2)会原地把参数重新赋值为0到2之间的均匀随机数,所以上面两次打印出的tensor完全相同;如果改变uniform_的取值范围(例如改成uniform_(-1,1)),得到的tensor里面的值也会随之改变。model.parameters()生成的是遍历模型参数的迭代器

    在这里记录一个我刚观察到的知识,param.dim()输出tensor的维度信息,维度与torch.FloatTensor of size 5x1x2x2有关,size为5x1x2x2是4维,size为5x1x2是3维以此类推,而Conv2d的这些size是由(Conv2d的前两个参数分别代表input image channel, output channel)输入图像的维度(RGB为3,灰度图像是1),输出的图像的维度(即filter的个数),还有kernel_size决定的。

    而输出结果中维度信息为1的tensor,是卷积得到的结果

    4、

    def init_hidden(self, batch_size):
        """Create the all-zero initial hidden state for a batch.

        The state has shape ``(self.depth, batch_size, self.n_d)`` and is
        allocated through ``new`` on the data of the module's first parameter.

        Args:
            batch_size: Size of the second dimension of the state.

        Returns:
            The zero state, or a ``(zeros, zeros)`` pair holding the same
            object twice when ``self.args.lstm`` is truthy.
        """
        first_param = next(self.parameters()).data
        blank = first_param.new(self.depth, batch_size, self.n_d).zero_()
        state = Variable(blank)
        return (state, state) if self.args.lstm else state
    

     关于weight = next(self.parameters()).data

    看看基于上面那个模型得到的结果

    import torch.nn as nn
    import torch.nn.functional as F
     
    class Model(nn.Module):
        """Small two-convolution module used to inspect ``parameters()``."""

        def __init__(self):
            """Create the two 2x2 convolutions (conv1 first, then conv2)."""
            super(Model, self).__init__()
            self.conv1 = nn.Conv2d(in_channels=1, out_channels=6,
                                   kernel_size=2, stride=2)
            self.conv2 = nn.Conv2d(in_channels=1, out_channels=5,
                                   kernel_size=2, stride=1)

        def forward(self, x):
            """Apply conv1 and conv2 in sequence, each followed by ReLU."""
            # NOTE(review): conv1 emits 6 channels while conv2 is declared with
            # 1 input channel, so this chain looks like it would fail on a
            # channel mismatch if forward were called -- confirm intent.
            hidden = F.relu(self.conv1(x))
            out = F.relu(self.conv2(hidden))
            return out
           
    model = Model()
    print(model)
    print('next')
    # next() takes the first item yielded by model.parameters().
    first_param = next(model.parameters())
    x = first_param.data
    print(x)
    

     结果

    Model (
      (conv1): Conv2d(1, 6, kernel_size=(2, 2), stride=(2, 2))
      (conv2): Conv2d(1, 5, kernel_size=(2, 2), stride=(1, 1))
    )
    next
    
    (0 ,0 ,.,.) = 
      0.2855 -0.0303
      0.1428 -0.4025
    
    (1 ,0 ,.,.) = 
     -0.0901  0.2736
     -0.1527 -0.2854
    
    (2 ,0 ,.,.) = 
      0.2193 -0.3886
     -0.4652  0.2307
    
    (3 ,0 ,.,.) = 
      0.1918  0.4587
     -0.0480 -0.0636
    
    (4 ,0 ,.,.) = 
      0.4017 -0.4123
      0.3016 -0.2714
    
    (5 ,0 ,.,.) = 
      0.2053  0.1252
     -0.2365 -0.3651
    [torch.FloatTensor of size 6x1x2x2]
    

     输出的是模型参数迭代器返回的第一个参数的数据,这里就是conv1的weight(size为6x1x2x2)。

  • 相关阅读:
    走向灵活软件之路——面向对象的六大原则
    StartUML破解
    非常实用的Android Studio快捷键
    Android Studio更新失败
    《Effect Java》学习笔记1———创建和销毁对象
    使用spring单元调试出错initializationError
    Spring注入的不同方式
    DNS域名解析的过程
    浏览器的缓存机制
    Http建立连接的方式
  • 原文地址:https://www.cnblogs.com/lindaxin/p/7978765.html
Copyright © 2020-2023  润新知