Comprendre un pytorch LSTM simple

import torch,ipdb
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

rnn = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)
input = Variable(torch.randn(5, 3, 10))
h0 = Variable(torch.randn(2, 3, 20))
c0 = Variable(torch.randn(2, 3, 20))
output, hn = rnn(input, (h0, c0))

ceci est l'exemple LSTM du docs. Je ne sais pas comprendre les choses suivantes:

  1. Qu'est-ce que la taille de sortie et pourquoi n'est-elle spécifiée nulle part?
  2. pourquoi l'entrée a-t-elle 3 dimensions? Que représentent 5 et 3?
  3. Qu'est-ce que 2 et 3 en h0 et c0, qu'est-ce que cela représente?

Edit:

import torch,ipdb
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F

num_layers=3
num_hyperparams=4
batch = 1
hidden_size = 20
rnn = nn.LSTM(input_size=num_hyperparams, hidden_size=hidden_size, num_layers=num_layers)

input = Variable(torch.randn(1, batch, num_hyperparams)) # (seq_len, batch, input_size)
h0 = Variable(torch.randn(num_layers, batch, hidden_size)) # (num_layers, batch, hidden_size)
c0 = Variable(torch.randn(num_layers, batch, hidden_size))
output, hn = rnn(input, (h0, c0))
affine1 = nn.Linear(hidden_size, num_hyperparams)

ipdb.set_trace()
print output.size()
print h0.size()
<!-RuntimeError: matrices attendues, tenseurs 3D, 2D à

17
demandé sur Abhishek Bhatia 2017-07-11 01:41:15

3 réponses

la sortie du LSTM est la sortie de tous les noeuds cachés sur la couche finale.

hidden_size - le nombre de blocs LSTM par couche.

input_size - le nombre d'entités en entrée par pas de temps.

num_layers - le nombre de couches cachées.

Au total il y a hidden_size * num_layers LSTM blocks.

les dimensions d'entrée sont (seq_len, batch, input_size).

seq_len - le nombre de pas de temps dans chaque entrée flux.

Le caché et de cellules dimensions sont les suivantes: (num_layers, batch, hidden_size)

sortie (seq_len, lot, hidden_size * num_directions): tenseur contenant les caractéristiques de sortie (h_t) à partir de la dernière couche de la RNN, pour chaque t.

il y aura donc hidden_size * num_directions sorties. Vous n'avez pas initialisé le RNN pour être bidirectionnel donc num_directions est 1. Si output_size = hidden_size.

Modifier: vous pouvez changer le nombre de sorties en utilisant une couche linéaire:

out_rnn, hn = rnn(input, (h0, c0))
lin = nn.Linear(hidden_size, output_size)
v1 = nn.View(seq_len*batch, hidden_size)
v2 = nn.View(seq_len, batch, output_size)
output = v2(lin(v1(out_rnn)))

Remarque:: pour cette réponse, j'ai supposé que nous parlions seulement de LSTMs non bidirectionnels.

Source: PyTorch docs.

18
répondu cdo256 2017-07-11 02:08:29

Vous pouvez définir

batch_first = True

si vous voulez faire des entrées et des sorties fournies comme

(batch_size, seq, input_size)

je le sais aujourd'hui, de partager avec vous.

2
répondu zzuczy 2017-11-27 02:11:29

la réponse de cdo256 est presque correcte. Il se trompe en se référant à ce que hidden_size signifie. Il explique:

hidden_size-le nombre de blocs LSTM par couche.

mais vraiment, là, c'est une meilleure explication:

chaque couche sigmoïde, tanh ou d'état caché dans la cellule est en fait un ensemble de noeuds, dont le nombre est égal à la taille de la couche cachée. Par conséquent, chacun des "noeuds" de la cellule LSTM est en fait un groupe de noeuds normaux du réseau neuronal, comme dans chaque couche d'un réseau neuronal étroitement connecté. Par conséquent, si vous définissez hidden_size = 10, alors chacun de vos blocs LSTM, ou cellules, aura des réseaux neuronaux avec 10 noeuds en eux. Le nombre total de blocs LSTM dans votre modèle LSTM sera équivalent à celui de la longueur de votre séquence.

Ceci peut être vu en analysant les différences dans les exemples entre nn.LSTM et nn.LSTMCell:

https://pytorch.org/docs/stable/nn.html#torch.nn.LSTM

et

https://pytorch.org/docs/stable/nn.html#torch.nn.LSTMCell

2
répondu Lsehovac 2018-07-19 17:18:30