如何使用 scipy.spatial.KDTree.query_ball_point

我正在尝试使用 Kdtree 数据结构从数组中移除最近的点,最好不要 for 循环。


import sys


import time


import scipy.spatial


class KDTree:

    """

    Nearest neighbor search class with KDTree

    """


    def __init__(self, data):

        # store kd-tree

        self.tree = scipy.spatial.cKDTree(data)


    def search(self, inp, k=1):

        """

        Search NN

        inp: input data, single frame or multi frame

        """


        if len(inp.shape) >= 2:  # multi input

            index = []

            dist = []


            for i in inp.T:

                idist, iindex = self.tree.query(i, k=k)

                index.append(iindex)

                dist.append(idist)


            return index, dist


        dist, index = self.tree.query(inp, k=k)

        return index, dist


    def search_in_distance(self, inp, r):

        """

        find points with in a distance r

        """


        index = self.tree.query_ball_point(inp, r)

        return np.asarray(index)



import numpy as np

import matplotlib.pyplot as plt

import matplotlib.animation as animation

start = time.time()

fig, ar = plt.subplots()

t = 0

R = 50.0

u = R *np.cos(t)

v = R *np.sin(t)


x = np.linspace(-100,100,51)

y = np.linspace(-100,100,51)


xx, yy = np.meshgrid(x,y)

points =np.vstack((xx.ravel(),yy.ravel())).T

Tree = KDTree(points)

ind = Tree.search_in_distance([u, v],10.0)

ar.scatter(points[:,0],points[:,1],c='k',s=1)

infected = points[ind]

ar.scatter(infected[:,0],infected[:,1],c='r',s=5)


def animate(i):

    global R,t,start,points

    ar.clear()

    u = R *np.cos(t)

    v = R *np.sin(t)

    ind = Tree.search_in_distance([u, v],10.0)

    ar.scatter(points[:,0],points[:,1],c='k',s=1)

    infected = points[ind]

    ar.scatter(infected[:,0],infected[:,1],c='r',s=5)

    #points = np.delete(points,ind)

    t+=0.01

    end = time.time()

    if end - start != 0:

        print((end - start), end="\r")

        start = end

ani = animation.FuncAnimation(fig, animate, interval=20)

plt.show()  

但无论我做什么,我都无法让 np.delete 处理 ball_query 方法返回的索引。我错过了什么?


我想让红色点在点数组的每次迭代中消失。


慕盖茨4494581
浏览 201回答 1
1回答

红糖糍粑

您的points数组是一个 Nx2 矩阵。您的ind索引是行索引列表。你需要的是指定你需要删除的轴,最终是这样的:points = np.delete(points,ind,axis=0)此外,一旦删除索引,请注意下一次迭代/计算中丢失的索引。也许您想要一个副本来删除点和绘图,另一个副本用于您不从中删除的计算。
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python