TensorFlow: Max d'un tenseur le long d'un axe
Ma question Est en deux parties connectées:
-
Comment calculer le max le long d'un certain axe d'un tenseur? Par exemple, si j'ai
x = tf.constant([[1,220,55],[4,3,-1]])
Je veux quelque chose comme
x_max = tf.max(x, axis=1) print sess.run(x_max) output: [220,4]
Je sais qu'il y a un
tf.argmax
et untf.maximum
, mais ni l'un ni l'autre ne donnent la valeur maximale le long d'un axe d'un seul tenseur. Pour l'instant j'ai une solution:x_max = tf.slice(x, begin=[0,0], size=[-1,1]) for a in range(1,2): x_max = tf.maximum(x_max , tf.slice(x, begin=[0,a], size=[-1,1]))
Mais cela semble moins qu'optimal. Est-il une meilleure façon de le faire?
-
Étant donné les indices d'un
argmax
d'un tenseur, comment indexer dans un autre tenseur en utilisant ces indices? En utilisant l'exemple dex
ci-dessus, Comment puis-je faire quelque chose comme ce qui suit:ind_max = tf.argmax(x, dimension=1) #output is [1,0] y = tf.constant([[1,2,3], [6,5,4]) y_ = y[:, ind_max] #y_ should be [2,6]
Je sais que le tranchage, comme la dernière ligne, n'existe pas encore dans TensorFlow(#206).
Ma question Est: Quelle est la meilleure solution de contournement pour mon cas spécifique (peut-être en utilisant d'autres méthodes comme gather, select,etc.)?
Informations supplémentaires: je sais que
x
ety
vont être bidimensionnels tenseurs seulement!
1 réponses
Le tf.reduce_max()
l'opérateur fournit exactement cette fonctionnalité. Par défaut, il calcule le maximum global du tenseur donné, mais vous pouvez spécifier une liste de reduction_indices
, qui a la même signification que axis
dans NumPy. Pour compléter votre exemple:
x = tf.constant([[1, 220, 55], [4, 3, -1]])
x_max = tf.reduce_max(x, reduction_indices=[1])
print sess.run(x_max) # ==> "array([220, 4], dtype=int32)"
Si vous calculez l'argmax en utilisant tf.argmax()
, Vous pouvez obtenir les valeurs d'un tenseur différent y
en aplatissant y
en utilisant tf.reshape()
, convertir les indices argmax en indices vectoriels comme suit, et utiliser tf.gather()
pour extraire les valeurs appropriées:
ind_max = tf.argmax(x, dimension=1)
y = tf.constant([[1, 2, 3], [6, 5, 4]])
flat_y = tf.reshape(y, [-1]) # Reshape to a vector.
# N.B. Handles 2-D case only.
flat_ind_max = ind_max + tf.cast(tf.range(tf.shape(y)[0]) * tf.shape(y)[1], tf.int64)
y_ = tf.gather(flat_y, flat_ind_max)
print sess.run(y_) # ==> "array([2, 6], dtype=int32)"