41

KNN 算法原理及代码实现

 5 years ago
source link: https://mp.weixin.qq.com/s/Ns35v3R-qT1XX_ZfYBc5Wg?amp%3Butm_medium=referral
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

在本文中,我们将讨论一种广泛使用的分类技术,称为K最近邻(KNN)。我们的重点主要集中在算法如何工作以及输入参数如何影响预测结果。

内容包括:

  • 何时使用KNN算法?

  • KNN算法原理

  • 如何选择K值

  • KNN算法伪码

  • Python实现KNN算法

  • 与scikit-learn比较

何时使用KNN算法

KNN算法可以同时应用到分类和回归问题。然而,KNN在实际应用中更多的是用于分类问题。为了更好的评价一个算法优劣,我们从以下三个重要的指标分析:

  1. 预测结果的可解读性

  2. 运行时间

  3. 预测准确性

KNN算法与逻辑回归(Logistic Regression)、决策树CART(Classification And Regression Tree,简称CART)、随机森林(Random Forest)对比如下表

eERr6rn.png!web

相对于其他算法,KNN算法在这三个评价指标表现相对均衡优异。

如何选择K值

让我们看一个简单的例子来理解KNN。下图中有三种点, 红色类 (3个带你)、 绿色类 (3个点)以及未知类别的某 蓝星

yQFZfqE.png!web

我们试图找到 蓝星 的所属类别,它要么属于 红色类 要么属于 绿色类 。算法KNN中的 K 指的是某点的K个用来投票的邻居,少数服从多数。K个邻居中投票最多的属性代表该点的属性。在本例子中我们将K设置为3,我们会给 蓝星

画一个圈围住最近的K=3个点。最终大概类似于下图

fuyMRnJ.png!web

我们看到 蓝星 最近的三个邻居都是 红色类 ,所以我们可以认为蓝星的类别是 红色类 。在本例中,对 蓝星 的分类是非常明显的,因为来自最近邻居的三张选票都投给了 红色类

。参数K的选择在该算法中非常重要。接下来我们将了解哪些因素需要考虑,以得出最佳的K。

如何选择K值

首先让我们理解K值到底如何影响KNN算法。如果我们

有很多蓝色点和红色点数据,使用不同K值,最终的分类效果大概如下图。我们发现随着K值的增大,分界面越来越平滑。

buUF7zY.jpg!web

一般在机器学习中我们要将数据集分为训练集和测试集,用训练集训练模型,再用测试集评价模型效果。这里我们绘制了不同k值下模型准确率。

fiAVny2.png!web

从上图中我们发现当 k=1k=无穷大 时,KNN的误差都很大。但是在某个点时能够将误差降低到最小。在本例中 k=10

KNN算法伪码

我们现在设计下KNN伪码:

一. 读取数据

二. 初始化k值

三. 为了得到测试数据的预测类别,对训练集每条数据进行迭代

  1. 计算测试数据与训练集中每一条数据的距离。这里我们选用比较通用的欧几里得距离作为距离的实现方法。

  2. 对距离进行升序排列

  3. 对排列结果选择前K个值

  4. 得到出现次数最多的类

  5. 返回测试数据的预测类别

用Python实现KNN算法

这里我们使用iris数据集来构件KNN模型。

from sklearn.datasets import load_iris
import numpy as np

#从sklearn中导入iris数据集
iris = load_iris()
X = iris.data
y = iris.target


#计算data1和data2的欧几里得距离。
def euclideanDistance(data1, data2):
    distance = 0
    #data1和data2长度一致,这里我们就使用data1的长度
    for x in range(len(data1)):
        distance += np.square(data1[x]-data2[x])
    return np.sqrt(distance)

    
#定义KNN模型
def KNN(X, y, testInstance, k):

    distances = dict()

    #三、为了得到测试数据的预测类别,对训练集每条数据进行迭代
    for idx, trainInstance in enumerate(X):
        #1. 计算测试数据与训练集中每一条数据的距离。
        dist = euclideanDistance(testInstance, trainInstance)
        distances[idx] = dist

    #2. 对距离进行升序排列
    sorted_d = sorted(distances.items(), key=lambda k:k[1])


    neighbors = []

    #3. 对排列结果选择前K个值
    for x in range(k):
        neighbors.append(sorted_d[x][0])


    classVotes = dict()

    #4. 得到出现次数最多的类
    for x in range(len(neighbors)):
        label = y[neighbors[x]]

        if label in classVotes:
            classVotes[label]+=1
        else:
            classVotes[label]=1



    #5. 返回测试数据的预测类别
    sortedVotes = sorted(classVotes.items(), key=lambda k:k[1], reverse=True)
    return (sortedVotes[0][0], neighbors)

我们先测试下

testInstance = [7.2, 3.6, 5.1, 2.5]
k = 1
predicted, neighbors = KNN(X, y, testInstance, k)
print('predicted:',predicted)
print('neighbors:',neighbors)

运行结果

    predicted: 2
    neighbors: [141]

现在我们将k设置为3.

k = 3
predicted, neighbors = KNN(X, y, testInstance, k)
print('predicted:',predicted)
print('neighbors:',neighbors)

运行结果

    predicted: 2
    neighbors: [141, 139, 120]
k = 5
predicted, neighbors = KNN(X, y, testInstance, k)
print('predicted:',predicted)
print('neighbors:',neighbors)

运行结果

    predicted: 2
    neighbors: [141, 139, 120, 145, 144]

与scikit-learn比较

from sklearn.neighbors import KNeighborsClassifier

knn = KNeighborsClassifier(n_neighbors=3)
knn.fit(X, y)

#scikit中输入的数据一般都为二维数组(矩阵)。
testdata = [testInstance]
print('predicted:', knn.predict(testdata))
print('neighbors', knn.kneighbors(testdata))

运行结果

    predicted: [2]
    neighbors (array([[0.6164414 , 0.76811457, 0.80622577]]), array([[141, 139, 120]]))

我们仅仅用了一个例子,当k=3时,从scikit库knn算法的运行结果与我们设计的完全一致。neigbors都是[141, 139, 120]

小节

KNN算法是最简单的分类算法之一,即便算法如此简单,模型表现极佳。 KNN算法也可用于回归问题。 与所讨论方法的唯一区别是使用最近邻居的平均值而不是最近邻居的投票。

往期文章

100G Python学习资料:从入门到精通! 免费下载

上百G文本数据集等你来认领|免费领取

一图教你高效入门数据科学

在校大学生如何用知识月入3000

为什么你要为2019,而不是2018做计划?     

我们应该成为“专才”还是成为“通才”?

史蛟:是什么同时导致了成功和热情?

2017年度15个最好的数据科学领域Python库

推荐系统与协同过滤、奇异值分解

机器学习之使用逻辑回归识别图片中的数字

应用PCA降维加速模型训练

对于中文,nltk能做哪些事情  

使用sklearn做自然语言处理-1

使用sklearn做自然语言处理-2

机器学习|八大步骤解决90%的NLP问题    

Python圈中的符号计算库-Sympy

Python中处理日期时间库的使用方法  

如何从文本中提取特征信息?

昨日财报

赞赏、点赞、转发、AD支持都是对大邓的认可和支持,希望大家在阅读后顺便帮大邓转发一下。

ZFJj6rZ.jpg!web

n6VJFnz.jpg!web


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK