class Model(nn.Module):
    def __init__(self, words, args):
        super(Model, self).__init__()
        self.args = args
        self.n_d = args.d
        self.depth = args.depth
        self.drop = nn.Dropout(args.dropout)
        self.embedding_layer = EmbeddingLayer(self.n_d, words)
        self.n_V = self.embedding_layer.n_V
        if args.lstm:
            self.rnn = nn.LSTM(self.n_d, self.n_d, self.depth,
                dropout=args.rnn_dropout
            )
        else:
            self.rnn = MF.SRU(self.n_d, self.n_d, self.depth,
                dropout=args.rnn_dropout,
                rnn_dropout=args.rnn_dropout,
                use_tanh=0
            )
        self.output_layer = nn.Linear(self.n_d, self.n_V)
        # tie weights
        self.output_layer.weight = self.embedding_layer.embedding.weight  # I ran this: the output weight is the same tensor as the per-word embedding vectors
        self.init_weights()
        if not args.lstm:
            self.rnn.set_bias(args.bias)

    def init_weights(self):
        val_range = (3.0/self.n_d)**0.5
        for p in self.parameters():
            if p.dim() > 1:  # matrix
                p.data.uniform_(-val_range, val_range)
            else:
                p.data.zero_()

    def forward(self, x, hidden):
        emb = self.drop(self.embedding_layer(x))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        output = output.view(-1, output.size(2))
        output = self.output_layer(output)
        return output, hidden

    def init_hidden(self, batch_size):  # zero-initialize the hidden state
        weight = next(self.parameters()).data
        zeros = Variable(weight.new(self.depth, batch_size, self.n_d).zero_())
        if self.args.lstm:
            return (zeros, zeros)
        else:
            return zeros

    def print_pnorm(self):  # p-norm of each parameter
        norms = ["{:.0f}".format(x.norm().data[0]) for x in self.parameters()]
        sys.stdout.write(" p_norm: {} ".format(norms))
This question arose from my attempt to understand the init_weights method of the Model class. I could not figure out what the method actually does, and in particular what parameters the iterator self.parameters() yields. My assumption was that it yields the trainable weights of each layer, so I took part of the SRU source code (the language-model example) and made it print out the parameters of Model. The code is as follows:
#coding:UTF-8
'''
Created on 2017-12-4

@author: lai
'''
import time
import random
import math
import argparse
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import sys

import cuda_functional as MF


def read_corpus(path, eos="</s>"):
    data = [ ]
    with open(path) as fin:
        for line in fin:
            data += line.split() + [ eos ]
    return data


def create_batches(data_text, map_to_ids, batch_size):
    data_ids = map_to_ids(data_text)
    N = len(data_ids)
    L = ((N-1) // batch_size) * batch_size
    x = np.copy(data_ids[:L].reshape(batch_size, -1).T)
    y = np.copy(data_ids[1:L+1].reshape(batch_size, -1).T)
    x, y = torch.from_numpy(x), torch.from_numpy(y)
    x, y = x.contiguous(), y.contiguous()
    return x, y


class EmbeddingLayer(nn.Module):  # assigns every word in the corpus its word vector
    def __init__(self, n_d, words, fix_emb=False):
        super(EmbeddingLayer, self).__init__()
        word2id = {}
        for w in words:
            if w not in word2id:
                word2id[w] = len(word2id)  # map each token of the text to an integer id
        self.word2id = word2id
        self.n_V, self.n_d = len(word2id), n_d  # n_V is the vocabulary size, n_d the hidden state size
        self.embedding = nn.Embedding(self.n_V, n_d)  # one embedding vector per word

    def forward(self, x):
        return self.embedding(x)

    def map_to_ids(self, text):  # map tokens to ids
        return np.asarray([self.word2id[x] for x in text], dtype='int64')


class Model(nn.Module):
    def __init__(self, words, args):
        super(Model, self).__init__()
        self.args = args
        self.n_d = args.d
        self.depth = args.depth
        self.drop = nn.Dropout(args.dropout)
        self.embedding_layer = EmbeddingLayer(self.n_d, words)
        self.n_V = self.embedding_layer.n_V
        if args.lstm:
            self.rnn = nn.LSTM(self.n_d, self.n_d, self.depth,
                dropout=args.rnn_dropout
            )
        else:
            self.rnn = MF.SRU(self.n_d, self.n_d, self.depth,
                dropout=args.rnn_dropout,
                rnn_dropout=args.rnn_dropout,
                use_tanh=0
            )
        self.output_layer = nn.Linear(self.n_d, self.n_V)
        # tie weights
        self.output_layer.weight = self.embedding_layer.embedding.weight  # shares the per-word embedding vectors
        self.init_weights()
        if not args.lstm:
            self.rnn.set_bias(args.bias)

    def init_weights(self):
        val_range = (3.0/self.n_d)**0.5
        for p in self.parameters():
            if p.dim() > 1:  # matrix
                p.data.uniform_(-val_range, val_range)
                print('222222', p.data)
            else:
                p.data.zero_()
                print('0000', p.data)


if __name__ == "__main__":
    argparser = argparse.ArgumentParser(sys.argv[0], conflict_handler='resolve')
    argparser.add_argument("--lstm", action="store_true")
    argparser.add_argument("--train", type=str, required=True, help="training file")
    argparser.add_argument("--batch_size", "--batch", type=int, default=32)
    argparser.add_argument("--unroll_size", type=int, default=35)
    argparser.add_argument("--max_epoch", type=int, default=300)
    argparser.add_argument("--d", type=int, default=910)
    argparser.add_argument("--dropout", type=float, default=0.7,
        help="dropout of word embeddings and softmax output"
    )
    argparser.add_argument("--rnn_dropout", type=float, default=0.2,
        help="dropout of RNN layers"
    )
    argparser.add_argument("--bias", type=float, default=-3,
        help="initial bias of highway gates",
    )
    argparser.add_argument("--depth", type=int, default=6)
    argparser.add_argument("--lr", type=float, default=1.0)
    argparser.add_argument("--lr_decay", type=float, default=0.98)
    argparser.add_argument("--lr_decay_epoch", type=int, default=175)
    argparser.add_argument("--weight_decay", type=float, default=1e-5)
    argparser.add_argument("--clip_grad", type=float, default=5)

    args = argparser.parse_args()
    print(args)

    train = read_corpus(args.train)
    model = Model(train, args)
    model.cuda()
    map_to_ids = model.embedding_layer.map_to_ids
    train = create_batches(train, map_to_ids, args.batch_size)
    print('111', model.parameters())
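Before running it, it is worth recalling what Module.parameters() actually is: it walks the module and all of its submodules and yields every registered nn.Parameter (weights and biases). The toy sketch below is mine, not part of the SRU code; named_parameters() additionally shows which submodule each tensor belongs to:

import torch.nn as nn

# Toy module, used only to illustrate what parameters()/named_parameters() yield.
toy = nn.Sequential(nn.Linear(4, 3), nn.Linear(3, 2))
for name, p in toy.named_parameters():
    print(name, p.size())
# 0.weight torch.Size([3, 4])
# 0.bias torch.Size([3])
# 1.weight torch.Size([2, 3])
# 1.bias torch.Size([2])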
Run the following command in the terminal:
python 2.py --train train.txt
Output:
Namespace(batch_size=32, bias=-3, clip_grad=5, d=910, depth=6, dropout=0.7, lr=1.0, lr_decay=0.98, lr_decay_epoch=175, lstm=False, max_epoch=300, rnn_dropout=0.2, train='train.txt', unroll_size=35, weight_decay=1e-05)
222222  4.8794e-02  5.0702e-02 -3.2630e-02  ... -1.0341e-02  4.7491e-02 -5.0599e-02
[torch.FloatTensor of size 10000x910]
222222 -6.1627e-03  1.9962e-02  5.6098e-02  ... -1.6219e-02 -2.1856e-02  3.2678e-02
[torch.FloatTensor of size 910x2730]
0000    0  0  0  ...  0  0  0
[torch.FloatTensor of size 1820]
[... the same pattern, one uniformly initialized 910x2730 matrix followed by one zeroed vector of size 1820, repeats six times in total ...]
0000    0  0  0  ...  0  0  0
[torch.FloatTensor of size 10000]
111 <generator object Module.parameters at 0x7f6fe8cc3eb8>
Here is the code of the init_weights method again:
def init_weights(self):
    val_range = (3.0/self.n_d)**0.5
    for p in self.parameters():
        if p.dim() > 1:  # matrix
            p.data.uniform_(-val_range, val_range)
            print('222222', p.data)
        else:
            p.data.zero_()
            print('0000', p.data)
The output above shows exactly the values written by p.data.uniform_(-val_range, val_range) and p.data.zero_(). My guess was that each printed pair is an SRU weight matrix (w) and bias vector (b). The first tensor, of size 10000x910, holds the embedding vectors of the 10,000 words in the vocabulary (and, because of the weight tying in __init__, it is also the weight of the output layer), while the last tensor, of size 10000, is the bias of the final fully-connected classification layer. That leaves six pairs of w and b in between, which raised a question: a recurrent network shares its parameters across time steps, so I expected only a single pair.
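Rather than guessing which printed tensor is which, one can print the parameter names next to their shapes. This is a small sketch of mine (assuming the model object built in the script above), not part of the original code:

# Hypothetical addition to the script above: name every tensor that
# self.parameters() yields, so each one can be traced back to its submodule
# (embedding_layer, one of the SRU layers inside self.rnn, or output_layer).
for name, p in model.named_parameters():
    print(name, tuple(p.size()))

The shapes themselves are also telling: 2730 = 3 x 910 and 1820 = 2 x 910, which is consistent with each SRU layer packing three projection matrices into one weight tensor and keeping two gate bias vectors per layer.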
To resolve this question, I took a piece of code that uses an LSTM for MNIST classification and printed out the parameters of its model. The code and the output are shown below.
Code:
import torch
from torch import nn
from torch.autograd import Variable
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

torch.manual_seed(1)    # reproducible

# Hyper Parameters
EPOCH = 1               # train the training data n times; to save time, we just train 1 epoch
BATCH_SIZE = 64
TIME_STEP = 28          # rnn time step / image height
INPUT_SIZE = 28         # rnn input size / image width
LR = 0.01               # learning rate
DOWNLOAD_MNIST = True   # set to True if you haven't downloaded the data yet

# Mnist digital dataset
train_data = dsets.MNIST(
    root='./mnist/',
    train=True,                         # this is training data
    transform=transforms.ToTensor(),    # Converts a PIL.Image or numpy.ndarray to
                                        # torch.FloatTensor of shape (C x H x W) and normalizes to the range [0.0, 1.0]
    download=DOWNLOAD_MNIST,            # download it if you don't have it
)

# plot one example
print(train_data.train_data.size())     # (60000, 28, 28)
print(train_data.train_labels.size())   # (60000)
plt.imshow(train_data.train_data[0].numpy(), cmap='gray')
plt.title('%i' % train_data.train_labels[0])
plt.show()

# Data Loader for easy mini-batch return in training
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

# convert test data into Variable, pick 2000 samples to speed up testing
test_x = Variable(test_data.test_data, volatile=True).type(torch.FloatTensor)[:2000]/255.   # shape (2000, 28, 28), values in range (0, 1)
test_data = dsets.MNIST(root='./mnist/', train=False, transform=transforms.ToTensor())
test_y = test_data.test_labels.numpy().squeeze()[:2000]    # convert to numpy array

class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()

        self.rnn = nn.LSTM(         # if use nn.RNN(), it hardly learns
            input_size=INPUT_SIZE,
            hidden_size=64,         # rnn hidden unit
            num_layers=2,           # number of rnn layers
            batch_first=True,       # input & output will have batch size as the 1st dimension, e.g. (batch, time_step, input_size)
        )

        self.out = nn.Linear(64, 10)

    def forward(self, x):
        # x shape (batch, time_step, input_size)
        # r_out shape (batch, time_step, output_size)
        # h_n shape (n_layers, batch, hidden_size)
        # h_c shape (n_layers, batch, hidden_size)
        r_out, (h_n, h_c) = self.rnn(x, None)   # None represents zero initial hidden state

        # choose r_out at the last time step
        out = self.out(r_out[:, -1, :])
        return out

    def init_weights(self):
        for p in self.parameters():
            print('PPP', p.data)

rnn = RNN()
print(rnn.init_weights())
Output:
torch.Size([60000, 28, 28])
torch.Size([60000])
PPP -2.0745e-02  1.2430e-01  5.5081e-02  ...  3.1813e-02  7.6921e-02  4.4233e-03
[torch.FloatTensor of size 256x28]
PPP -2.4325e-03  1.1478e-02  9.3458e-02  ... -1.5258e-02  5.6519e-02  6.1370e-02
[torch.FloatTensor of size 256x64]
PPP  0.0282 -0.0362  0.0864  ...
[torch.FloatTensor of size 256]
PPP -0.0896 -0.0394  0.0575  ...
[torch.FloatTensor of size 256]
PPP -6.6907e-02 -1.1469e-01  6.4129e-02  ...
[torch.FloatTensor of size 256x64]
PPP -9.3750e-02 -8.5315e-02 -3.2224e-02  ...
[torch.FloatTensor of size 256x64]
PPP -0.0661  0.0039  0.0343  ...
[torch.FloatTensor of size 256]
PPP  0.0130 -0.0655  0.0321  ...
[torch.FloatTensor of size 256]
PPP  0.0991  0.1218 -0.0816  ...
[torch.FloatTensor of size 10x64]
PPP  0.0109 -0.0778 -0.0501  0.0163  0.0763 -0.0792  0.1141 -0.0127  0.0162  0.0808
[torch.FloatTensor of size 10]
None
For more on LSTM in PyTorch, see pytorch之LSTM.
I printed the LSTM's parameters and checked them against the official PyTorch documentation (pytorch之LSTM). They are all Variables, and note that here, too, there is more than one set of w and b: there are two sets, because the LSTM is built with num_layers=2, and with num_layers=3 there would be three. This gave me the hint I needed: after changing the number of layers (depth) of the SRU, its parameter list changed in the same way. My conclusion is that a recurrent network is not limited to a single recurrent unit; the parameters are shared across time steps, but each stacked layer has its own weights and biases. I had previously assumed there was only one.
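To make the relationship between num_layers and the parameter list explicit, here is a small sketch of mine (not from the original post) listing the parameters of an nn.LSTM configured like the one above; every stacked layer l owns weight_ih_l{l}, weight_hh_l{l}, bias_ih_l{l} and bias_hh_l{l}, and these are reused at every time step:

import torch.nn as nn

# Same configuration as the MNIST example: input_size=28, hidden_size=64, num_layers=2.
lstm = nn.LSTM(input_size=28, hidden_size=64, num_layers=2, batch_first=True)
for name, p in lstm.named_parameters():
    print(name, tuple(p.size()))
# weight_ih_l0 (256, 28)   <- 256 = 4 * hidden_size (the four LSTM gates stacked)
# weight_hh_l0 (256, 64)
# bias_ih_l0 (256,)
# bias_hh_l0 (256,)
# weight_ih_l1 (256, 64)
# weight_hh_l1 (256, 64)
# bias_ih_l1 (256,)
# bias_hh_l1 (256,)

This matches the ten tensors printed above: eight LSTM tensors (four per layer) plus the weight and bias of self.out.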
The SRU's parameters likewise live in the model as Variables and can therefore be updated during training.
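That is also why training works: the iterator returned by model.parameters() is handed to the optimizer, and after backpropagation the optimizer updates every one of those tensors in place. A minimal, self-contained sketch of this pattern (using a toy linear model of mine, not the original training script):

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

# Toy model standing in for the real one; the update mechanism is the same.
net = nn.Linear(4, 2)
optimizer = optim.SGD(net.parameters(), lr=0.1)

x = Variable(torch.randn(8, 4))
y = Variable(torch.randn(8, 2))

before = net.weight.data.clone()
loss = nn.MSELoss()(net(x), y)
optimizer.zero_grad()
loss.backward()     # gradients are accumulated on every tensor from net.parameters()
optimizer.step()    # ...and the optimizer updates those same tensors in place
print((net.weight.data - before).abs().sum())   # non-zero: the weight has been updated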