如何创建一个 3D (shape=m,n,o) 数组,沿特定(第三)轴有一个

假设我有一个形状为 (2, 2) 的二维数组,其中包含索引


x = np.array([[2, 0], [3, 1]])

我想要做的是创建一个形状为 (2, 2, 4) 的 3D 数组,其沿第三个轴的值为 1,它们的位置由 给出x,因此:


y = np.zeros(shape=(2,2,4))

myfunc(array=y, indices=x, axis=2)


array([[[0, 0, 1, 0],

        [1, 0, 0, 0]],

       [[0, 0, 0, 1],

        [0, 1, 0, 0]]])

到目前为止,我还没有找到任何索引方法。一个for循环可能能够做到这一点,但我肯定有一个更快,矢量化方法。


繁星淼淼
浏览 131回答 2
2回答

胡说叔叔

您正在寻找的是所谓的高级索引。要正确索引整数数组,您需要有一组广播到正确形状的数组。由于x已经与两个维度对齐,因此您只需要沿每个轴制作带有索引的二维数组。np.ogrid对此有所帮助,因为它创建了广播到正确形状的最小范围数组:a, b = np.ogrid[:2, :2]y[a, b, x] = 1结果ogrid等价于a = np.arange(2).reshape(-1, 1)b = np.arange(2).reshape(1, -1)或者a = np.arange(2)[:, None]b = np.arange(2)[None, :]你也可以写一个单行:y[(*tuple(slice(None, n) for n in x.shape), x)] = 1
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Java