Use advanced-indexing -
    m, n = a.shape[1:]
    I, J = np.ogrid[:m, :n]
    a_max_values = a[idx, I, J]
    b_max_values = b[idx, I, J]
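To see why this works, here is a small self-contained check. It assumes `a` and `b` are 3-D arrays of the same shape and that `idx = a.argmax(axis=0)`; the shapes and random data are made up for illustration. `np.ogrid` yields open grids of shape `(m, 1)` and `(1, n)`, which broadcast against `idx` of shape `(m, n)`, so element `[i, j]` of the result is `a[idx[i, j], i, j]`.

```python
import numpy as np

# Illustrative data (assumed shapes, not from the original question)
rng = np.random.default_rng(0)
a = rng.integers(0, 100, size=(4, 3, 5))
b = rng.integers(0, 100, size=(4, 3, 5))
idx = a.argmax(axis=0)          # shape (3, 5)

m, n = a.shape[1:]
I, J = np.ogrid[:m, :n]         # I has shape (m, 1), J has shape (1, n)

# idx, I and J broadcast to a common (m, n) shape
a_max_values = a[idx, I, J]
b_max_values = b[idx, I, J]

# a indexed at its own argmax reproduces the max along axis 0
assert np.array_equal(a_max_values, a.max(axis=0))
```

`b_max_values` then holds the values of `b` at the positions where `a` is largest along the first axis.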
In general:
    def argmax_to_max(arr, argmax, axis):
        """argmax_to_max(arr, arr.argmax(axis), axis) == arr.max(axis)"""
        new_shape = list(arr.shape)
        del new_shape[axis]
        # np.ogrid may return a list or a tuple depending on the NumPy version;
        # list() makes insert() available either way
        grid = list(np.ogrid[tuple(map(slice, new_shape))])
        grid.insert(axis, argmax)
        return arr[tuple(grid)]
Unfortunately, this is rather more awkward than such a natural operation should be.
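On NumPy 1.15 or later, `np.take_along_axis` offers a shorter route to the same result. This is a different technique from the grid construction above, shown only as a possible alternative; the helper name `argmax_to_max_take` is hypothetical.

```python
import numpy as np

def argmax_to_max_take(arr, argmax, axis):
    """Hypothetical alternative: select arr's values at `argmax` along `axis`."""
    # Re-insert the reduced axis so the index array has arr.ndim dimensions
    expanded = np.expand_dims(argmax, axis)
    # Pick one element along `axis` per position, then drop the length-1 axis
    return np.take_along_axis(arr, expanded, axis=axis).squeeze(axis)

# Check against arr.max for every axis of a small random array
arr = np.random.default_rng(0).random((2, 3, 4))
for axis in range(arr.ndim):
    result = argmax_to_max_take(arr, arr.argmax(axis), axis)
    assert np.array_equal(result, arr.max(axis))
```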
To index an n-dimensional array with an (n-1)-dimensional index array, we can simplify this a bit with a helper that builds the grid of indices for all axes, for example:
    def all_idx(idx, axis):
        # list() because np.ogrid may return a tuple, which has no insert()
        grid = list(np.ogrid[tuple(map(slice, idx.shape))])
        grid.insert(axis, idx)
        return tuple(grid)
Then use it to index into the input arrays:
    axis = 0
    a_max_values = a[all_idx(idx, axis=axis)]
    b_max_values = b[all_idx(idx, axis=axis)]
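As a rough end-to-end check, the sketch below repeats the helper so it runs on its own and tries it on a non-zero axis. The array shapes, the random data and the choice `axis = 1` are assumptions made for illustration.

```python
import numpy as np

def all_idx(idx, axis):
    # Open grids for every axis of idx, with idx itself inserted at `axis`
    grid = list(np.ogrid[tuple(map(slice, idx.shape))])
    grid.insert(axis, idx)
    return tuple(grid)

# Illustrative data (assumed shapes)
rng = np.random.default_rng(0)
a = rng.integers(0, 100, size=(4, 3, 5))
b = rng.integers(0, 100, size=(4, 3, 5))

axis = 1
idx = a.argmax(axis=axis)                    # shape (4, 5)

a_max_values = a[all_idx(idx, axis=axis)]    # equals a.max(axis=1)
b_max_values = b[all_idx(idx, axis=axis)]    # b where a is largest

assert np.array_equal(a_max_values, a.max(axis=axis))
```

Note that the helper inserts `idx` at position `axis` in a list, so this sketch assumes a non-negative `axis`.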
Divakar