import torch
from torch import nn
# Demo: run a 2-layer LSTM over a random sequence and inspect the output shapes.
# nn.LSTM defaults to batch_first=False, so tensors are laid out as
# (seq_len, batch, features).
SEQ_LEN, BATCH, INPUT_SIZE = 10, 3, 100
HIDDEN_SIZE, NUM_LAYERS = 20, 2

# Build the model first: parameter initialisation draws from the global RNG
# before the input tensor below is sampled.
lstm = nn.LSTM(
    input_size=INPUT_SIZE,
    hidden_size=HIDDEN_SIZE,
    num_layers=NUM_LAYERS,
)
print(lstm)

x = torch.randn(SEQ_LEN, BATCH, INPUT_SIZE)

# No initial (h_0, c_0) is passed, so both start as zeros.
#   out  -> top layer's hidden state at every time step: (SEQ_LEN, BATCH, HIDDEN_SIZE)
#   h, c -> final hidden / cell state of every layer:    (NUM_LAYERS, BATCH, HIDDEN_SIZE)
out, (h, c) = lstm(x)

for label, tensor in (
    ('out shape:', out),  # [10, 3, 20]
    ('h shape:', h),      # [2, 3, 20]
    ('c shape:', c),      # [2, 3, 20]
):
    print(label, tensor.shape)