Régler la valeur unique à L'intérieur de Tensor - TensorFlow

je me sens gêné de demander cela, mais comment ajustez-vous une seule valeur à l'intérieur d'un tenseur? Supposons que vous vouliez ajouter " 1 " à une seule valeur dans votre tenseur?

le Faire par l'indexation ne fonctionne pas:

TypeError: 'Tensor' object does not support item assignment

une approche serait de construire un tenseur de 0 de forme identique. Et en ajustant un 1 à la position que vous voulez. Alors vous ajouteriez les deux tenseurs ensemble. Encore une fois, cela se heurte au même problème qu'auparavant.

j'ai lu via l'API docs plusieurs fois et ne semble pas comprendre comment faire cela. Merci à l'avance!

45
demandé sur LeavesBreathe 2016-01-08 23:58:51

2 réponses

mise à jour: TensorFlow 1.0 comprend un tf.scatter_nd() opérateur, qui peut être utilisé pour créer des delta ci-dessous sans créer un tf.SparseTensor.


C'est en fait étonnamment délicate avec l'ops! Peut-être que quelqu'un peut suggérer une meilleure façon de conclure ce qui suit, Mais voici une façon de le faire.

disons que vous avez un tf.constant() tenseur:

c = tf.constant([[0.0, 0.0, 0.0],
                 [0.0, 0.0, 0.0],
                 [0.0, 0.0, 0.0]])

...et si vous souhaitez ajouter 1.0 à l'emplacement [1, 1]. Une façon de le faire est de définir un tf.SparseTensor,delta, représentant la variation:

indices = [[1, 1]]  # A list of coordinates to update.

values = [1.0]  # A list of values corresponding to the respective
                # coordinate in indices.

shape = [3, 3]  # The shape of the corresponding dense tensor, same as `c`.

delta = tf.SparseTensor(indices, values, shape)

alors vous pouvez utiliser le tf.sparse_tensor_to_dense() op pour faire un dense tenseur de deltac:

result = c + tf.sparse_tensor_to_dense(delta)

sess = tf.Session()
sess.run(result)
# ==> array([[ 0.,  0.,  0.],
#            [ 0.,  1.,  0.],
#            [ 0.,  0.,  0.]], dtype=float32)
58
répondu mrry 2017-03-03 17:23:35

Que Diriez-vous de tf.scatter_update(ref, indices, updates) ou tf.scatter_add(ref, indices, updates)?

ref[indices[...], :] = updates
ref[indices[...], :] += updates

Voir .

6
répondu Liping Liu 2017-01-01 19:27:58