综合技术

机器学习(八)-基于KNN分类算法的手写识别系统

微信扫一扫,分享到朋友圈

机器学习(八)-基于KNN分类算法的手写识别系统
0

1 项目介绍

基于k-近邻分类器(KNN)的手写识别系统, 这里构造的系统只能识别数字0到9。

数据集和项目源代码

  • 难点: 图形信息如何处理?

图像转换为文本格式

2 准备数据

将图像转换为测试向量

  • 训练集:

    • 目录trainingDigits
    • 大约2000个例子
    • 每个数字大约有200个样本;
  • 测试集

    • 目录testDigits
    • 大约900个测试数据。

将图像格式化处理为一个向量。我们将把一个32×32的二进制图像矩阵转换为1×1024的向量, 如下图所示,

import numpy as np
def img2vector(filename):
    """
    # 将图像数据转换为(1,1024)向量
    :param filename: 
    :return: (1,1024)向量
    """
    # 生成一个1*1024且值全为0的向量;
    returnVect = np.zeros((1, 1024))
    # 读取要转换的信息;
    file = open(filename)
    # 依次填充
    # 读取每一行数据;
    for i in range(32):
        lineStr = file.readline()
        # 读取每一列数据;
        for j in range(32):
            returnVect[0, 32 * i + j] = int(lineStr[j])
    return returnVect

3 实施 KNN 算法

对未知类别属性的数据集中的每个点依次执行以下操作, 与上一个案例代码相同:

(1) 计算已知类别数据集中的点与当前点之间的距离;

(2) 按照距离递增次序排序;

(3) 选取与当前点距离最小的k个点;

(4) 确定前k个点所在类别的出现频率;

(5) 返回前k个点出现频率最高的类别作为当前点的预测分类。

def classify(inX, dataSet, labels, k):
    """
    :param inX: 要预测的数据
    :param dataSet: 我们要传入的已知数据集
    :param labels:  我们要传入的标签
    :param k: KNN里的k, 也就是说我们要选几个近邻
    :return: 排序的结果
    """
    dataSetSize = dataSet.shape[0]  # (6,2) 6
    # tile会重复inX, 把他重复成(datasetsize, 1)型的矩阵
    # print(inX)
    # (x1 - y1), (x2- y2)
    diffMat = np.tile(inX, (dataSetSize, 1)) - dataSet
    # 平方
    sqDiffMat = diffMat ** 2
    # 相加, axis=1 行相加
    sqDistance = sqDiffMat.sum(axis=1)
    # 开根号
    distances = sqDistance ** 0.5
    # print(distances)
    # 排序 输出的是序列号index,并不是值
    sortedDistIndicies = distances.argsort()
    # print(sortedDistIndicies)

    classCount = {}
    for i in range(k):
        voteLabel = labels[sortedDistIndicies[i]]
        classCount[voteLabel] = classCount.get(voteLabel, 0) + 1
        # print(classCount)
    sortedClassCount = sorted(classCount.items(), key=lambda d: float(d[1]), reverse=True)
    return sortedClassCount[0]

4 测试算法

使用 k-近邻算法识别手写数字

  • 测试集里面的信息;

def handWritingClassTest(k):
    """
    # 测试手写数字识别错误率的代码
    :param k:
    :return:
    """
    hwLabels = []
    import os
    # 读取所有的训练集文件;
    trainingFileList = os.listdir('data/knn-digits/trainingDigits')
    # 获取训练集个数;
    m = len(trainingFileList)
    # 生成m行1024列全为0的矩阵;
    trainingMat = np.zeros((m, 1024))
    # 填充训练集矩阵;
    for i in range(m):
        fileNameStr = trainingFileList[i]    # fileNameStr: 0_0.txt
        fileStr = fileNameStr.split('.')[0]  # fileStr: 0_0
        classNumStr = int(fileStr.split('_')[0])    # (数字分类的结果)classNumStr: 0
        # 填写真实的数字结果;
        hwLabels.append(classNumStr)
        # 图形的数据: (1,1024)向量
        trainingMat[i, :] = img2vector("data/knn-digits/trainingDigits/%s" % fileNameStr)

    # 填充测试集矩阵;
    testFileList = os.listdir('data/knn-digits/testDigits')
    # 默认错误率为0;
    errorCount = 0.0
    # 测试集的总数;
    mTest = len(testFileList)
    # 填充测试集矩阵;
    for i in range(mTest):
        fileNameStr = testFileList[i]
        fileStr = fileNameStr.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])
        vectorTest = img2vector("data/knn-digits/testDigits/%s" % fileNameStr)

        # 判断预测结果与真实结果是否一致?
        result = classify(vectorTest, trainingMat, hwLabels, k)

        if result != classNumStr:
            # 如果不一致,则统计出来, 计算错误率;
            errorCount += 1.0
            print("[预测失误]:分类结果是:%d, 真实结果是:%d" % (result, classNumStr))
    print("错误总数:%d" % errorCount)
    print("错误率:%f" % (errorCount / mTest))
    print("模型准确率:%f" %(1-errorCount / mTest))
    return errorCount


print(handWritingClassTest(2))
  • 效果展示

5 KNN算法手写识别的缺点

算法的执行效率并不高。

  • 每个测试向量做2000次距离计算,每个距离计算包括了1024个维度浮点运算,总计要执行900次;
  • 需要为测试向量准备2MB的存储空间

有没有更好的方法?

  • k决策树就是k-近邻算法的优化版,可以节省大量的计算开销。
阅读原文...

微信扫一扫,分享到朋友圈

机器学习(八)-基于KNN分类算法的手写识别系统
0
SegmentFault博客

VirtualBox 6.0.6 Released with Support for Linux 5.0 and Linux 5.1 Kernels

上一篇

prototype与__proto__的区别

下一篇

评论已经被关闭。

插入图片

热门分类

往期推荐

机器学习(八)-基于KNN分类算法的手写识别系统

长按储存图像,分享给朋友