TensorFlow: Se souvenir de l'état LSTM pour la prochaine fournée (LSTM)
étant donné un modèle LSTM formé, je veux effectuer une inférence pour les pas de temps simples, i.e. seq_length = 1
dans l'exemple ci-dessous. Après chaque pas de temps, les États LSTM internes (mémoire et caché) doivent être mémorisés pour le prochain "lot". Pour le tout début de l'inférence, les États LSTM internes init_c, init_h
sont calculés à partir de l'entrée. Ils sont ensuite stockés dans un objet LSTMStateTuple
qui est transmis au LSTM. Pendant l'entraînement, cet état est mis à jour à chaque pas de temps. Toutefois, pour inférence je veux que le state
soit sauvegardé entre les lots, c'est-à-dire que les états initiaux n'ont besoin d'être calculés qu'au tout début et après que les États LSTM doivent être sauvegardés après chaque " lot " (n=1).
j'ai trouvé cette question liée StackOverflow: Tensorflow, meilleure façon de sauver l'état dans RNNs? . Cependant cela ne fonctionne que si state_is_tuple=False
, mais ce comportement sera bientôt déprécié par TensorFlow (voir rnn_cell.py ). Keras semble avoir une belle enveloppe pour faire stateful LSTMs possible mais je ne sais pas la meilleure façon d'atteindre ce dans TensorFlow. Cette question sur le GitHub TensorFlow est également liée à ma question: https://github.com/tensorflow/tensorflow/issues/2838
N'importe qui de bonnes suggestions pour construire un modèle stateful LSTM?
inputs = tf.placeholder(tf.float32, shape=[None, seq_length, 84, 84], name="inputs")
targets = tf.placeholder(tf.float32, shape=[None, seq_length], name="targets")
num_lstm_layers = 2
with tf.variable_scope("LSTM") as scope:
lstm_cell = tf.nn.rnn_cell.LSTMCell(512, initializer=initializer, state_is_tuple=True)
self.lstm = tf.nn.rnn_cell.MultiRNNCell([lstm_cell] * num_lstm_layers, state_is_tuple=True)
init_c = # compute initial LSTM memory state using contents in placeholder 'inputs'
init_h = # compute initial LSTM hidden state using contents in placeholder 'inputs'
self.state = [tf.nn.rnn_cell.LSTMStateTuple(init_c, init_h)] * num_lstm_layers
outputs = []
for step in range(seq_length):
if step != 0:
scope.reuse_variables()
# CNN features, as input for LSTM
x_t = # ...
# LSTM step through time
output, self.state = self.lstm(x_t, self.state)
outputs.append(output)
2 réponses
j'ai découvert qu'il était plus facile de sauvegarder l'état entier pour toutes les couches dans un conteneur.
init_state = np.zeros((num_layers, 2, batch_size, state_size))
...
state_placeholder = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])
ensuite, déballez-le et créez un tuple de LSTMStateTuples avant d'utiliser l'Api natif TensorFlow RNN.
l = tf.unpack(state_placeholder, axis=0)
rnn_tuple_state = tuple(
[tf.nn.rnn_cell.LSTMStateTuple(l[idx][0], l[idx][1])
for idx in range(num_layers)]
)
RNN passe dans l'API:
cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)
cell = tf.nn.rnn_cell.MultiRNNCell([cell]*num_layers, state_is_tuple=True)
outputs, state = tf.nn.dynamic_rnn(cell, x_input_batch, initial_state=rnn_tuple_state)
la variable state
sera alors attribuée au prochain lot comme un attribut.
Tensorflow, meilleure façon de sauver l'état dans RNNs? était en fait ma question d'origine. Le code ci-dessous est la façon dont j'utilise les tuples de l'état.
with tf.variable_scope('decoder') as scope:
rnn_cell = tf.nn.rnn_cell.MultiRNNCell \
([
tf.nn.rnn_cell.LSTMCell(512, num_proj = 256, state_is_tuple = True),
tf.nn.rnn_cell.LSTMCell(512, num_proj = WORD_VEC_SIZE, state_is_tuple = True)
], state_is_tuple = True)
state = [[tf.zeros((BATCH_SIZE, sz)) for sz in sz_outer] for sz_outer in rnn_cell.state_size]
for t in range(TIME_STEPS):
if t:
last = y_[t - 1] if TRAINING else y[t - 1]
else:
last = tf.zeros((BATCH_SIZE, WORD_VEC_SIZE))
y[t] = tf.concat(1, (y[t], last))
y[t], state = rnn_cell(y[t], state)
scope.reuse_variables()
plutôt que d'utiliser tf.nn.rnn_cell.LSTMStateTuple
je crée juste une liste de listes qui fonctionne très bien. Dans cet exemple, Je ne sauve pas l'état. Cependant, vous auriez pu facilement créer un État à partir de variables et utiliser assign pour sauvegarder les valeurs.