一、KNN算法简介
邻近算法,或者说K最近邻(kNN,k-NearestNeighbor)分类算法是数据挖掘分类技术中最简单的方法之一。所谓K最近邻,就是k个最近的邻居的意思,说的是每个样本都可以用它最接近的k个邻居来代表。
kNN算法的核心思想是如果一个样本在特征空间中的k个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。该方法在确定分类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。 kNN方法在类别决策时,只与极少量的相邻样本有关。由于kNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,kNN方法较其他方法更为适合。
二、算法过程
1.读取数据集
2.处理数据集数据 清洗,采用留出法hold-out拆分数据集:训练集、测试集
3.实现KNN算法类:
1)遍历训练数据集,离差平方和计算各点之间的距离
2)对各点的距离数组进行排序,根据输入的k值取对应的k个点
3)k个点中,统计每个点出现的次数,权重为距离的导数,得到最大的值,该值的索引就是我们计算出的判定类别
三、代码实现及数据分析
import numpy as np import pandas as pd # 读取鸢尾花数据集,header参数来指定标题的行。默认为0。如果没有标题,则使用None。 data = pd.read_csv("你的目录/Iris.csv",header=0) # 显示前n行记录。默认n的值为5。 #data.head() # 显示末尾的n行记录。默认n的值为5。 #data.tail() # 随机抽取样本。默认抽取一条,我们可以通过参数进行指定抽取样本的数量。 # data.sample(10) # 将类别文本映射成为数值类型 data["Species"] = data["Species"].map({"Iris-virginica": 0, "Iris-setosa": 1, "Iris-versicolor": 2}) # 删除不需要的Id列。 data.drop("Id", axis=1, inplace=True ) data.drop_duplicates(inplace=True) ## 查看各个类别的鸢尾花具有多少条记录。 data["Species"].value_counts()
分析:首先读取数据集,如下图
最后一列为数据集的分类名称,但是在程序中,我们更倾向于使用如0、1、2数字来表示分类,所以对数据集进行处理,处理后的数据集如下:
然后采用留出法对数据集进行拆分,一部分用作训练,一部分用作测试,如下图:
#构建训练集与测试集,用于对模型进行训练与测试。 # 提取出每个类比的鸢尾花数据 t0 = data[data["Species"] == 0] t1 = data[data["Species"] == 1] t2 = data[data["Species"] == 2] # 对每个类别数据进行洗牌 random_state 每次以相同的方式洗牌 保证训练集与测试集数据取样方式相同 t0 = t0.sample(len(t0), random_state=0) t1 = t1.sample(len(t1), random_state=0) t2 = t2.sample(len(t2), random_state=0) # 构建训练集与测试集。 train_X = pd.concat([t0.iloc[:40, :-1], t1.iloc[:40, :-1], t2.iloc[:40, :-1]] , axis=0)#截取前40行,除最后列外的列,因为最后一列是y train_y = pd.concat([t0.iloc[:40, -1], t1.iloc[:40, -1], t2.iloc[:40, -1]], axis=0) test_X = pd.concat([t0.iloc[40:, :-1], t1.iloc[40:, :-1], t2.iloc[40:, :-1]], axis=0) test_y = pd.concat([t0.iloc[40:, -1], t1.iloc[40:, -1], t2.iloc[40:, -1]], axis=0)
实现KNN算法类:
#定义KNN类,用于分类,类中定义两个预测方法,分为考虑权重不考虑权重两种情况 class KNN: ''' 使用Python语言实现K近邻算法。(实现分类) ''' def __init__(self, k): '''初始化方法 Parameters ----- k:int 邻居的个数 ''' self.k = k def fit(self,X,y): '''训练方法 Parameters ---- X : 类数组类型,形状为:[样本数量, 特征数量] 待训练的样本特征(属性) y : 类数组类型,形状为: [样本数量] 每个样本的目标值(标签)。 ''' #将X转换成ndarray数组 self.X = np.asarray(X) self.y = np.asarray(y) def predict(self,X): """根据参数传递的样本,对样本数据进行预测。 Parameters ----- X : 类数组类型,形状为:[样本数量, 特征数量] 待训练的样本特征(属性) Returns ----- result : 数组类型 预测的结果。 """ X = np.asarray(X) result = [] # 对ndarray数组进行遍历,每次取数组中的一行。 for x in X: # 对于测试集中的每一个样本,依次与训练集中的所有样本求距离。 dis = np.sqrt(np.sum((x - self.X) ** 2, axis=1)) ## 返回数组排序后,每个元素在原数组(排序之前的数组)中的索引。 index = dis.argsort() # 进行截断,只取前k个元素。【取距离最近的k个元素的索引】 index = index[:self.k] # 返回数组中每个元素出现的次数。元素必须是非负的整数。【使用weights考虑权重,权重为距离的倒数。】 count = np.bincount(self.y[index], weights= 1 / dis[index]) # 返回ndarray数组中,值最大的元素对应的索引。该索引就是我们判定的类别。 # 最大元素索引,就是出现次数最多的元素。 result.append(count.argmax()) return np.asarray(result)
#创建KNN对象,进行训练与测试。 knn = KNN(k=3) #进行训练 knn.fit(train_X,train_y) #进行测试 result = knn.predict(test_X) # display(result) # display(test_y) display(np.sum(result == test_y)) display(np.sum(result == test_y)/ len(result))
得出计算结果:
26
0.9629629629629629
得出该模型计算的结果中,有26条记录与测试集相等,准确率为96%
接下来绘制散点图:
#导入可视化所必须的库。 import matplotlib as mpl import matplotlib.pyplot as plt mpl.rcParams["font.family"] = "SimHei" mpl.rcParams["axes.unicode_minus"] = False #绘制散点图。为了能够更方便的进行可视化,这里只选择了两个维度(分别是花萼长度与花瓣长度)。 # {"Iris-virginica": 0, "Iris-setosa": 1, "Iris-versicolor": 2}) # 设置画布的大小 plt.figure(figsize=(10, 10)) # 绘制训练集数据 plt.scatter(x=t0["SepalLengthCm"][:40], y=t0["PetalLengthCm"][:40], color="r", label="Iris-virginica") plt.scatter(x=t1["SepalLengthCm"][:40], y=t1["PetalLengthCm"][:40], color="g", label="Iris-setosa") plt.scatter(x=t2["SepalLengthCm"][:40], y=t2["PetalLengthCm"][:40], color="b", label="Iris-versicolor") # 绘制测试集数据 right = test_X[result == test_y] wrong = test_X[result != test_y] plt.scatter(x=right["SepalLengthCm"], y=right["PetalLengthCm"], color="c", marker="x", label="right") plt.scatter(x=wrong["SepalLengthCm"], y=wrong["PetalLengthCm"], color="m", marker=">", label="wrong") plt.xlabel("花萼长度") plt.ylabel("花瓣长度") plt.title("KNN分类结果显示") plt.legend(loc="best") plt.show()
程序运行结果如下:
四、思考与优化
①尝试去改变邻居的数量。
②在考虑权重的情况下,修改邻居的数量。
③对比查看结果上的差异。
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持亿速云。
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:is@yisu.com进行举报,并提供相关证据,一经查实,将立刻删除涉嫌侵权内容。