对于输入数据,通过统计与它距离最近的 k 个数据的标签,选择出现次数最多的分类,成为它的分类。
需要引入 Python 中的 numpy 以及 operator 包。
from numpy import *
import numpy as np
import operator
def createdataset():
group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
labels = ['A', 'A', 'B', 'B']
return group, labels
kNN 算法处理未知类别的数据集中的点的流程:
1. 计算已知类别的点与未知类别的点之前的距离(可以利用欧拉公式);
2. 将上一步计算得到距离从小到大排序;
3. 选出距离最小的前 k 个点;
4. 统计出这 k 个点对于的类别总数以及次数;
5. 选择上一步中出现次数最多的类别作为当前点的类别。
根据之前课程学习的经验,编程实现过程中尽量要做到计算向量化。
def classify0(inx, dataset, labels, k):
datasetsize = dataset.shape[0]
diffmat = tile(inx, (datasetsize, 1)) - dataset # tile([1,2],(3,2))==>[[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 1, 2]]
sqdiffmat = diffmat**2
sqdistances = sqdiffmat.sum(axis=1)
distances = sqdistances**0.5 # 计算和已知数据的距离
sorteddistindicies = distances.argsort() # 将元素从小到大排序,返回索引的序列
classcount = {} # 记录结果
for i in range(k):
voteilabel = labels[sorteddistindicies[i]]
classcount[voteilabel] = classcount.get(voteilabel, 0) + 1 # get(key, default) 返回value,不存在时返回 default
sortedclasscount = sorted(classcount.iteritems(), key=operator.itemgetter(1), reverse=True) # 取 value 值排序
return sortedclasscount[0][0]
打开作者提供的数据文件 datingTestSet2.txt,我们可以发现其对应的数据格式如下:
- 每一行都是一条记录
- 第一列:每年获得的飞行常客里程数
- 第二列:玩视频游戏所耗时间百分比
- 第三列:每周消费的冰淇凌公升数
- 第四列:标签
40920 8.326976 0.953952 3
14488 7.153469 1.673904 2
26052 1.441871 0.805124 1
75136 13.147394 0.428964 1
38344 1.669788 0.134296 1
72993 10.141740 1.032955 1
35948 6.830792 1.213192 3
42666 13.276369 0.543880 3
67497 8.631577 0.749278 1
所以,我们首先要将数据以矩阵的方式读入。其中,我们将最后一项当成是分类。
def file2matrix(filename):
fr = open(filename)
arrayoflines = fr.readlines()
numberoflines = len(arrayoflines)
returnmat = zeros((numberoflines, 3))
classlabelvalue = []
index = 0
for line in arrayoflines:
line = line.strip()
listfromline = line.split('\t')
returnmat[index, :] = listfromline[0:3]
classlabelvalue.append(int(listfromline[-1])) # 转化成为int
index = index + 1
return returnmat, classlabelvalue
书本中,为了更直观的了解数据,还使用 Matplotlib 库进行画图,这里不继续介绍。
由于不同的数据取值的范围不同,为了减少某些值特别大或者特别小的影响,我们需要采取数值归一化,将取值范围减少到 0 到 1 之间。比如 采用
def autonorm(dataset):
minvals = dataset.min(0)
maxvals = dataset.max(0)
ranges = maxvals - minvals
normdataset = zeros(shape(dataset))
m = dataset.shape[0]
normdataset = dataset - tile(minvals, (m, 1))
normdataset = normdataset/tile(ranges, (m, 1))
return normdataset, ranges, minvals
- 注意:输入数据验证时,也应该使用上一步的均值进行归一化。
def classfyperson():
resultlist = ['not at all', 'is small doses', 'in large doses']
percenttats = float(raw_input("percentage of time spent playing video games ?"))
ffmiles = float(raw_input("frequent filer miles earned per year?"))
icecream = float(raw_input("liters of ice cream consumed per year?"))
datingdatamat, datinglabels = file2matrix('datingTestSet2.txt')
normmat, ranges, minvals = autonorm(datingdatamat) # 归一化数值
inarr = array([ffmiles, percenttats, icecream])
classfypersonresult = classify0((inarr-minvals)/ranges, normmat, datinglabels, 3)
print "You will probably like this person: ", resultlist[classfypersonresult - 1]
利用函数将 32*32 的二进制图像转换为一个向量。
def img2vector(filename):
returnvec = zeros((1, 1024))
fr = open(filename)
for i in range(32):
linestr = fr.readline()
for j in range(32):
returnvec[0, 32*i+j] = int(linestr[j])
return returnvec
- 利用 listdir 函数,处理从目录中批量读取文件名,之后再进行文件内容的读取。
def handwritingclasstest():
hwlabels = []
trainingfilelist = listdir('digits/trainingDigits')
m = len(trainingfilelist)
trainingmat = zeros((m, 1024))
for i in range(m):
filenamestr = trainingfilelist[i]
filestr = filenamestr.split('.')[0]
classnumstr = int(filestr.split('_')[0])
hwlabels.append(classnumstr)
trainingmat[i, :] = img2vector('digits/trainingDigits/%s' % filenamestr)
testfilelist = listdir('digits/testDigits')
mtest = len(testfilelist)
errorcount = 0.0
for i in range(mtest):
testfilenamestr = testfilelist[i]
testfilestr = testfilenamestr.split('.')[0]
testclassnumstr = int(testfilestr.split('_')[0])
testvect = img2vector('digits/testDigits/%s' % testfilenamestr)
result = classify0(testvect, trainingmat, hwlabels, 3)
if (testclassnumstr != result):
errorcount += 1
print 'the classifier came back with: %d, the real number is: %d.\n' % (result, testclassnumstr)
print 'the total error is %d, the total error rate is: %f\n' % (errorcount, errorcount/float(mtest))
- 该方法最简单且有效;
- 在使用该算法时,需要保存所有的训练数据,以及遍历测试数据与训练数据之间的关系。耗时长,空间大;
- 无法给出数据集的基础结构信息。