Using a 2D matrix as indices into a 3D matrix in NumPy?
It seems you are using a 2D array as the index array and a 3D array to select values from. In that case you can use NumPy's advanced indexing:
```python
import numpy as np

# a : 2D array of indices, b : 3D array from which values are picked
m, n = a.shape
I, J = np.ogrid[:m, :n]
out = b[a, I, J]  # or b[a, np.arange(m)[:, None], np.arange(n)]
```
If you meant to use `a` to index into the last axis instead, just move `a` there: `b[I, J, a]`.
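As a sanity check, here is a minimal sketch comparing the advanced-indexing expression with an explicit double loop that spells out the rule `out[i, j] = b[a[i, j], i, j]`. The helper name `pick_first_axis_loop` and the random test data are hypothetical, chosen only for illustration. The last part moves the "choice" axis of `b` to the end with `np.moveaxis` to show the `b[I, J, a]` variant selects the same values.

```python
import numpy as np

def pick_first_axis_loop(a, b):
    """Hypothetical reference loop: out[i, j] = b[a[i, j], i, j]."""
    m, n = a.shape
    out = np.empty((m, n), dtype=b.dtype)
    for i in range(m):
        for j in range(n):
            out[i, j] = b[a[i, j], i, j]
    return out

# Illustrative random data (not from the question)
rng = np.random.default_rng(0)
a = rng.integers(0, 4, size=(5, 6))        # indices into axis 0 of b
b = rng.integers(11, 99, size=(4, 5, 6))   # values, shape (k, m, n)

m, n = a.shape
I, J = np.ogrid[:m, :n]

# Index arrays a, I, J broadcast to (m, n); one element of b is picked per (i, j)
fancy = b[a, I, J]
assert np.array_equal(fancy, pick_first_axis_loop(a, b))

# Last-axis variant: b_last has shape (m, n, k) and b_last[i, j, k] == b[k, i, j]
b_last = np.moveaxis(b, 0, -1)
assert np.array_equal(b_last[I, J, a], fancy)
```

The loop is only there to make the indexing rule explicit; the vectorized form is what you would use in practice.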
Sample run:

```python
>>> np.random.seed(1234)
>>> a = np.random.randint(0,2,(3,3))
>>> b = np.random.randint(11,99,(2,3,3))
>>> a # Index array
array([[1, 1, 0],
       [1, 0, 0],
       [0, 1, 1]])
>>> b # values array
array([[[60, 34, 37],
        [41, 54, 41],
        [37, 69, 80]],

       [[91, 84, 58],
        [61, 87, 48],
        [45, 49, 78]]])
>>> m,n = a.shape
>>> I,J = np.ogrid[:m,:n]
>>> out = b[a, I, J]
>>> out
array([[91, 84, 37],
       [61, 54, 41],
       [37, 49, 78]])
```
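To see why this works, it can help to look at the shapes `np.ogrid` produces and to compare it with the `np.arange` form mentioned in the code comment and with the dense `np.mgrid` grids. The sketch below hard-codes `a` and `b` from the sample run so the output can be checked against it.

```python
import numpy as np

# Index and value arrays copied from the sample run above
a = np.array([[1, 1, 0],
              [1, 0, 0],
              [0, 1, 1]])
b = np.array([[[60, 34, 37],
               [41, 54, 41],
               [37, 69, 80]],
              [[91, 84, 58],
               [61, 87, 48],
               [45, 49, 78]]])

m, n = a.shape
I, J = np.ogrid[:m, :n]
print(I.shape, J.shape)   # (3, 1) (1, 3): open grids that broadcast to (3, 3)

out_ogrid = b[a, I, J]
out_arange = b[a, np.arange(m)[:, None], np.arange(n)]
Im, Jm = np.mgrid[:m, :n]  # dense (3, 3) grids, same selection
out_mgrid = b[a, Im, Jm]

assert np.array_equal(out_ogrid, out_arange)
assert np.array_equal(out_ogrid, out_mgrid)
print(out_ogrid.tolist())  # [[91, 84, 37], [61, 54, 41], [37, 49, 78]]
```

All three forms select the same elements; `np.ogrid` just avoids materializing full `(m, n)` row and column index arrays.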
If your matrices get much bigger than 3x3 and your indexes remain binary (only 0s and 1s), you could also do:

```python
np.where(a, b[1], b[0])
```

This takes values from `b[1]` wherever `a` is nonzero and from `b[0]` elsewhere, without building any index grids. Other than that corner case (or if you like code-golfing one-liners), the advanced-indexing approach above is probably better.
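A minimal sketch, on hypothetical random data, checking that the `np.where` one-liner matches advanced indexing when `a` is binary. The last part shows `np.take_along_axis` (available since NumPy 1.15), a different function from the ones used in either approach above, as one way to handle non-binary indices along axis 0.

```python
import numpy as np

rng = np.random.default_rng(42)
m, n = 200, 300
a = rng.integers(0, 2, size=(m, n))         # binary indices only
b = rng.integers(11, 99, size=(2, m, n))    # exactly two "layers" along axis 0

I, J = np.ogrid[:m, :n]
via_fancy = b[a, I, J]
# np.where treats any nonzero entry as True, so this matches only for 0/1 indices
via_where = np.where(a, b[1], b[0])
assert np.array_equal(via_fancy, via_where)

# Non-binary indices: take_along_axis needs an index array with the same ndim as b3
a3 = rng.integers(0, 5, size=(m, n))
b3 = rng.integers(11, 99, size=(5, m, n))
via_take = np.take_along_axis(b3, a3[np.newaxis], axis=0)[0]  # shape (m, n)
assert np.array_equal(via_take, b3[a3, I, J])
```

Note that `np.where` would silently map an index of 2 or more to `b[1]`, so it is only a drop-in replacement when the indices really are 0 or 1.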