Trouver l'indice d'un tableau numpy dans une liste

import numpy as np
foo = [1, "hello", np.array([[1,2,3]]) ]

je m'attends

foo.index( np.array([[1,2,3]]) ) 

retour

2

mais au lieu de cela j'obtiens

ValueError: La valeur de vérité d'un tableau avec plus d'un élément est ambigüe. Utilisez un.tout () ou A. tous ()

rien de mieux que ma solution actuelle? Ça semble inefficace.

def find_index_of_array(list, array):
    for i in range(len(list)):
        if np.all(list[i]==array):
            return i

find_index_of_array(foo, np.array([[1,2,3]]) )
# 2
18
demandé sur kmario23 2016-12-22 00:12:58

5 réponses

La raison de l'erreur ici est évidemment parce que numpy est ndarray remplace == retourner un tableau plutôt qu'un booléen.

AFAIK, il n'y a pas de solution simple ici. Ce qui suit fonctionnera aussi longtemps que l'

np.all(val == array) peu de travaux.

next((i for i, val in enumerate(lst) if np.all(val == array)), -1)

si ce bit fonctionne ou non dépend de façon critique de ce que sont les autres éléments dans le tableau et s'ils peuvent être comparés avec des tableaux vides.

11
répondu mgilson 2016-12-21 22:13:56

pour la performance, vous pourriez vouloir traiter seulement les tableaux NumPy dans la liste d'entrée. Ainsi, nous pourrions taper-vérifier avant d'aller dans la boucle et indexer dans les éléments qui sont des tableaux.

ainsi, une implémentation serait -

def find_index_of_array_v2(list1, array1):
    idx = np.nonzero([type(i).__module__ == np.__name__ for i in list1])[0]
    for i in idx:
        if np.all(list1[i]==array1):
            return i
2
répondu Divakar 2016-12-21 21:35:33

et celle-ci?

arr = np.array([[1,2,3]])
foo = np.array([1, 'hello', arr], dtype=np.object)

# if foo array is of heterogeneous elements (str, int, array)
[idx for idx, el in enumerate(foo) if type(el) == type(arr)]

# if foo array has only numpy arrays in it
[idx for idx, el in enumerate(foo) if np.array_equal(el, arr)]

Sortie:

[2]

Remarque: Cela permettra également de travailler même si foo est une liste de. Je viens de le mettre comme un numpy array ici.

2
répondu kmario23 2016-12-22 16:54:23

Le problème ici (vous le savez probablement déjà, mais juste de le répéter), c'est que list.index fonctionne comme suit:

for idx, item in enumerate(your_list):
    if item == wanted_item:
        return idx

La ligne if item == wanted_item est le problème, car il convertit implicitement item == wanted_item booléen. Mais numpy.ndarray (sauf si c'est un scalaire) soulève cette ValueError puis:

ValueError: La valeur de vérité d'un tableau avec plus d'un élément est ambigu. Utilisez un.tout () ou A. tous ()

j'utilise généralement une couche mince (adaptateur) autour de numpy.ndarray chaque fois que j'ai besoin d'utiliser des fonctions python comme list.index:

class ArrayWrapper(object):

    __slots__ = ["_array"]  # minimizes the memory footprint of the class.

    def __init__(self, array):
        self._array = array

    def __eq__(self, other_array):
        # array_equal also makes sure the shape is identical!
        # If you don't mind broadcasting you can also use
        # np.all(self._array == other_array)
        return np.array_equal(self._array, other_array)

    def __array__(self):
        # This makes sure that `np.asarray` works and quite fast.
        return self._array

    def __repr__(self):
        return repr(self._array)

ces emballages minces sont plus chers que l'utilisation manuelle de certains enumerate boucle ou compréhension mais vous n'avez pas à ré-implémenter les fonctions python. En supposant que la liste ne contient que des tableaux numpy (sinon, vous devez en faire if ... else ... vérification):

list_of_wrapped_arrays = [ArrayWrapper(arr) for arr in list_of_arrays]

Après cette étape, vous pouvez utiliser tous vos fonctions python sur cette liste:

>>> list_of_arrays = [np.ones((3, 3)), np.ones((3)), np.ones((3, 3)) * 2, np.ones((3))]
>>> list_of_wrapped_arrays.index(np.ones((3,3)))
0
>>> list_of_wrapped_arrays.index(np.ones((3)))
1

ces enveloppes Ne sont plus des maquereaux, mais vous avez des enveloppes minces, donc la liste supplémentaire est assez petite. Donc, en fonction de vos besoins, vous pouvez garder le enveloppés liste et la liste et choisir sur laquelle les opérations, par exemple vous pouvez également list.count les tableaux identiques maintenant:

>>> list_of_wrapped_arrays.count(np.ones((3)))
2

ou list.remove:

>>> list_of_wrapped_arrays.remove(np.ones((3)))
>>> list_of_wrapped_arrays
[array([[ 1.,  1.,  1.],
        [ 1.,  1.,  1.],
        [ 1.,  1.,  1.]]), 
 array([[ 2.,  2.,  2.],
        [ 2.,  2.,  2.],
        [ 2.,  2.,  2.]]), 
 array([ 1.,  1.,  1.])]

Solution 2: sous-classe et ndarray.view

ce l'approche utilise des sous-classes explicites de numpy.array. Il a l'avantage que vous obtenez toute la fonctionnalité builtin array-et modifiez seulement l'opération demandée (qui serait __eq__):

class ArrayWrapper(np.ndarray):
    def __eq__(self, other_array):
        return np.array_equal(self, other_array)

>>> your_list = [np.ones(3), np.ones(3)*2, np.ones(3)*3, np.ones(3)*4]

>>> view_list = [arr.view(ArrayWrapper) for arr in your_list]

>>> view_list.index(np.array([2,2,2]))
1

encore une fois vous obtenez la plupart des méthodes list de cette façon:list.remove,list.count en plus list.index.

cependant cette approche peut donner un comportement subtil si une opération utilise implicitement __eq__. Vous pouvez toujours re-interpréter est aussi simple que le tableau de numpy en utilisant np.asarray ou .view(np.ndarray):

>>> view_list[1]
ArrayWrapper([ 2.,  2.,  2.])

>>> view_list[1].view(np.ndarray)
array([ 2.,  2.,  2.])

>>> np.asarray(view_list[1])
array([ 2.,  2.,  2.])

Alternative: Le Remplacement Des __bool__ (ou __nonzero__ pour python 2)

au lieu de corriger le problème dans le __eq__ méthode vous pouvez aussi outrepasser __bool__ ou __nonzero__:

class ArrayWrapper(np.ndarray):
    # This could also be done in the adapter solution.
    def __bool__(self):
        return bool(np.all(self))

    __nonzero__ = __bool__

Encore une fois ce qui rend l' list.index fonctionne comme prévu:

>>> your_list = [np.ones(3), np.ones(3)*2, np.ones(3)*3, np.ones(3)*4]
>>> view_list = [arr.view(ArrayWrapper) for arr in your_list]
>>> view_list.index(np.array([2,2,2]))
1

mais cela va définitivement modifier plus de comportement! Par exemple:

>>> if ArrayWrapper([1,2,3]):
...     print('that was previously impossible!')
that was previously impossible!
2
répondu MSeifert 2016-12-22 19:04:26

Ceci devrait faire l'affaire:

[i for i,j in enumerate(foo) if j.__class__.__name__=='ndarray']
[2]
0
répondu Mahdi Ghelichi 2018-01-18 02:50:58