• torch.nn.LSTM()函数维度详解


    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    lstm=nn.LSTM(input_size,                     hidden_size,                      num_layers)
    x                         seq_len,                          batch,                              input_size
    h0            num_layers× imes×num_directions,   batch,                             hidden_size
    c0            num_layers× imes×num_directions,   batch,                             hidden_size

    output                 seq_len,                         batch,                num_directions× imes×hidden_size
    hn            num_layers× imes×num_directions,   batch,                             hidden_size
    cn            num_layers× imes×num_directions,    batch,                            hidden_size

    举个例子:
    对句子进行LSTM操作

    假设有100个句子(sequence),每个句子里有7个词,batch_size=64,embedding_size=300

    此时,各个参数为:
    input_size=embedding_size=300
    batch=batch_size=64
    seq_len=7

    另外设置hidden_size=100, num_layers=1

    import torch
    import torch.nn as nn
    lstm = nn.LSTM(300, 100, 1)
    x = torch.randn(7, 64, 300)
    h0 = torch.randn(1, 64, 100)
    c0 = torch.randn(1, 64, 100)
    output, (hn, cn)=lstm(x, (h0, c0))

    >>
    output.shape torch.Size([7, 64, 100])
    hn.shape torch.Size([1, 64, 100])
    cn.shape torch.Size([1, 64, 100])
    ---------------------
    作者:huxuedan01
    来源:CSDN
    原文:https://blog.csdn.net/m0_37586991/article/details/88561746
    版权声明:本文为博主原创文章,转载请附上博文链接!

  • 相关阅读:
    被刷登录接口
    移动端布局方案
    容易遗忘的Javascript点
    java 笔记02
    java 笔记01
    C# 日常整理
    reac-native 0.61开发环境
    DOS命令收集
    vue整理日常。
    php7.1+apache2.4.x+mysql5.7安装配置(目前windows)
  • 原文地址:https://www.cnblogs.com/jfdwd/p/11187387.html
Copyright © 2020-2023  润新知