Différence entre numpy dot () et Python 3.5+ multiplication matricielle @
J'ai récemment déménagé à Python 3.5 et j'ai remarqué que le nouvel opérateur de multiplication matricielle (@) se comporte parfois différemment de l'opérateurnumpy dot . Dans l'exemple, pour les tableaux 3d:
import numpy as np
a = np.random.rand(8,13,13)
b = np.random.rand(8,13,13)
c = a @ b # Python 3.5+
d = np.dot(a, b)
L'opérateur @
renvoie un tableau de forme:
c.shape
(8, 13, 13)
Alors que la fonction np.dot()
renvoie:
d.shape
(8, 13, 8, 13)
Comment puis-je reproduire le même résultat avec numpy dot? Existe-il d'autres différences significatives?
3 réponses
L'opérateur @
appelle la méthode __matmul__
du tableau, pas dot
. Cette méthode est également présente dans l'API de la fonction np.matmul
.
>>> a = np.random.rand(8,13,13)
>>> b = np.random.rand(8,13,13)
>>> np.matmul(a, b).shape
(8, 13, 13)
De la documentation:
matmul
diffère dedot
de deux façons importantes.
- la Multiplication par scalaires n'est pas autorisée.
- Les piles de matrices sont diffusées ensemble comme si les matrices étaient des éléments.
Le dernier point indique clairement que dot
et matmul
les méthodes se comportent différemment lorsque les tableaux 3D (ou dimensionnels supérieurs) sont passés. Citant de la documentation un peu plus:
Pour matmul
:
Si L'un des arguments est N-D, n > 2, Il est traité comme une pile de matrices résidant dans les deux derniers index et diffusé en conséquence.
Pour np.dot
:
Pour les tableaux 2-D, il est équivalent à la multiplication matricielle, et pour les tableaux 1-D au produit interne des vecteurs (sans conjugaison complexe). pour n dimensions, c'est un produit de somme sur le dernier axe de a et l'avant-dernier de b
La réponse de @ ajcr explique comment les dot
et matmul
(invoqués par le symbole @
) diffèrent. En regardant un exemple simple, on voit clairement comment les deux se comportent différemment lorsqu'ils fonctionnent sur des "piles de matriciels" ou des tenseurs.
Pour clarifier les différences, prenez un tableau 4x4 et renvoyez le produit dot
et le produit matmul
avec une pile de matrices 2x4x3 ou un tenseur.
import numpy as np
fourbyfour = np.array([
[1,2,3,4],
[3,2,1,4],
[5,4,6,7],
[11,12,13,14]
])
twobyfourbythree = np.array([
[[2,3],[11,9],[32,21],[28,17]],
[[2,3],[1,9],[3,21],[28,7]],
[[2,3],[1,9],[3,21],[28,7]],
])
print('4x4*4x2x3 dot:\n {}\n'.format(np.dot(fourbyfour,twobyfourbythree)))
print('4x4*4x2x3 matmul:\n {}\n'.format(np.matmul(fourbyfour,twobyfourbythree)))
Les produits de chaque opération apparaissent ci-dessous. Notez comment le produit dot est,
...un somme du produit sur le dernier axe de a et l'avant-dernier de b
Et comment le produit de la matrice est formé en diffusant la matrice ensemble.
4x4*4x2x3 dot:
[[[232 152]
[125 112]
[125 112]]
[[172 116]
[123 76]
[123 76]]
[[442 296]
[228 226]
[228 226]]
[[962 652]
[465 512]
[465 512]]]
4x4*4x2x3 matmul:
[[[232 152]
[172 116]
[442 296]
[962 652]]
[[125 112]
[123 76]
[228 226]
[465 512]]
[[125 112]
[123 76]
[228 226]
[465 512]]]
En mathématiques, je pense que le point dans numpy a plus de sens
Dot(a,b)_{i,j,k,a,b,c} = \sum_m a_{i,j,k,m}b_{a,b,m,c}
Car il donne le produit dot quand a et b sont des vecteurs, ou la multiplication matricielle quand a et b sont des matrices
Quant à l'opération matmul dans numpy, elle se compose de parties du résultatdot , et elle peut être définie comme
Matmul (a,B)_{i,j, k, C} = \ sum_m a_{i,j,k,m}b_{i,j,m,c}
Ainsi, vous pouvez voir que matmul(a,b) retourne un tableau avec une petite forme, ce qui a une consommation de mémoire plus petite et a plus de sens dans les applications. En particulier, en combinant avec diffusion , vous pouvez obtenir
Matmul (a,B)_{i, j, k, L} = \ sum_m a_{i, j, k, m} b_{j, m, L}
Par exemple.
À partir des deux définitions ci-dessus, vous pouvez voir les exigences pour utiliser ces deux opérations. Supposer a. shape=(s1,S2, s3,S4) et B. shape=(t1,t2,t3, t4)
-
Pour utiliser dot(A,B) Vous avez besoin de
1. **t3=s4**;
-
Pour utiliser matmul(A,B) Vous avez besoin de
- t3=s4
- t2 = S2 , ou l'un des t2 et s2 est 1
- t1 = s1 , ou l'un des t1 et s1 est 1
Utilisez le code suivant pour vous convaincre.
Exemple de Code
import numpy as np
for it in xrange(10000):
a = np.random.rand(5,6,2,4)
b = np.random.rand(6,4,3)
c = np.matmul(a,b)
d = np.dot(a,b)
#print 'c shape: ', c.shape,'d shape:', d.shape
for i in range(5):
for j in range(6):
for k in range(2):
for l in range(3):
if not c[i,j,k,l] == d[i,j,k,j,l]:
print it,i,j,k,l,c[i,j,k,l]==d[i,j,k,j,l] #you will not see them