sklearn 和从头开始的不同 Kmean 结果

我尝试比较sklearn包中和从头开始的 kmean 聚类结果。暂存代码如下所示:


import matplotlib.pyplot as plt

from matplotlib import style


style.use('ggplot')

import numpy as np


colors = 10 * ["g", "r", "c", "b", "k"]



class K_Means:

    def __init__(self, k=3, tol=0.001, max_iter=300):

        self.k = k

        self.tol = tol

        self.max_iter = max_iter


    def fit(self, data):


        self.centroids = {}


        for i in range(self.k):

            self.centroids[i] = data[i]


        for i in range(self.max_iter):

            self.classifications = {}


            for i in range(self.k):

                self.classifications[i] = []


            for featureset in data:

                distances = [np.linalg.norm(featureset - self.centroids[centroid]) for centroid in self.centroids]

                classification = distances.index(min(distances))

                self.classifications[classification].append(featureset)


            prev_centroids = dict(self.centroids)


            for classification in self.classifications:

                self.centroids[classification] = np.average(self.classifications[classification], axis=0)


            optimized = True


            for c in self.centroids:

                original_centroid = prev_centroids[c]

                current_centroid = self.centroids[c]

                if np.sum((current_centroid - original_centroid) / original_centroid * 100.0) > self.tol:

                    print(np.sum((current_centroid - original_centroid) / original_centroid * 100.0))

                    optimized = False


            if optimized:

                break


    def predict(self, data):

        distances = [np.linalg.norm(data - self.centroids[centroid]) for centroid in self.centroids]

        classification = distances.index(min(distances))

        return classification


但由于收敛质心不同,结果也不同。sklearn 的散点图:

https://img1.mukewang.com/64d1e8e00001728103990331.jpg

同时,上面代码的散点图:

https://img4.mukewang.com/64d1e8ed000175e705490410.jpg

我想知道临时代码中存在哪些错误。

慕码人8056858
浏览 112回答 1
1回答

斯蒂芬大帝

K 均值高度依赖于初始化条件,即均值的起点。scikit-learn 可以根据数据进行智能初始化。如果您仔细阅读文档,您可能可以配置 scikit-learn 的版本以匹配您自己的版本。
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python