当前位置: 首页 > news >正文

AI入门:从零开始实现手写数字识别(1)

AI入门:从零开始实现手写数字识别(1)

  • 前言
  • 技术要求
  • 声明
  • 机器学习(Machine Learning)
    • 简介
    • 机器学习的类型
    • 机器学习的基本流程
  • KNN算法
    • 模型简介
    • 基本工作流程
    • 算法的优缺点
    • 补充
    • 基于KNN算法实现手写数字识别
      • 数据准备
      • 数据预处理与模型训练
      • 预测数据
      • PCA降维与超参数优化
  • 问题与总结

前言

AI大模型已经渗透到我们生活的方方面面。就业市场上,AI大模型开发工程师是各家企业争抢的人才。很多人想学,却不知从何下手。一看到那堆术语就懵了——ML、DL、LLM、Agent、MCP、RAG……完全不知道该从哪开始。
从今天起,我打算开一个系列。我们从机器学习(ML)入手,由易到难实现手写数字识别。最后搭一个交互式网站,把不同算法的效果做可视化。

技术要求

开始之前,需要具备以下基础:

  1. 编程语言基础:具备基本的 python 语法基础。
  2. 数学基础:了解微积分,线性代数,概率论中的基本概念。
  3. AI辅助编程:后期网站开发将使用 Code Agent 辅助项目开发。

声明

这个系列主要是记录个人学习过程。写下来是为了厘清思路,顺便加深理解。
内容来源包括但不限于网络搜索、AI 生成、视频学习、各大论坛等
如果发现内容有误,恳请大佬在评论区指正,我会及时修改。学习中有疑问也欢迎私信或评论区讨论。
我们正式开始学习吧。

机器学习(Machine Learning)

简介

机器学习(Machine Learning,简称 ML)是人工智能(AI)的一个核心分支。简单来说,它是一门让计算机从数据中自动学习规律,并利用这些规律对未知数据进行预测或决策的科学,而无需人类为它编写明确的、死板的规则代码。

机器学习的类型

根据训练数据学习目标的不同,机器学习主要分为以下三种类型:

  1. 监督学习 (Supervised Learning)
  • 定义:监督学习是指使用带标签的数据进行训练,模型通过学习输入数据与标签之间的关系,来做出预测或分类。
  • 常见应用:分类:预测离散的类别(识别猫狗图片、判断邮件是否为垃圾邮件、疾病诊断)、回归:预测连续的数值(预测房价、预测明天的气温)。
  • 常用算法:线性回归、逻辑回归、决策树、随机森林、支持向量机(SVM)。
  1. 无监督学习 (Unsupervised Learning)
  • 定义:无监督学习使用没有标签的数据,模型试图在数据中发现潜在的结构或模式。
  • 常见应用:聚类:将相似的数据分到同一组(如:用户画像分群、异常检测)、降维:在保留核心信息的前提下减少数据的维度,便于可视化或计算(如:PCA 主成分分析)。
  • 常用算法:K-Means聚类、层次聚类、DBSCAN。
  1. 强化学习 (Reinforcement Learning)
  • 定义:强化学习通过与环境互动,智能体在试错中学习最佳策略,以最大化长期回报。每次行动后,系统会收到奖励或惩罚,来指导行为的改进。
  • 常见应用:自动驾驶,游戏AI。

机器学习的基本流程

  1. 数据收集:数据是机器学习的燃料,获取数据是机器学习必备的第一步。
  2. 数据预处理:对数据集中的缺失值,异常值等脏数据进行清洗,填补。并选择有助于模型学习的最相关特征。良好的数据决定模型效果的上限。
  3. 模型选择与训练:根据需求以及数据选择合适的机器学习模型,模型通过优化算法(如梯度下降等)最小化损失函数,拟合数据规律。
  4. 模型评估:使用测试集数据来检验模型的准确率,召回率等指标。
  5. 模型部署:将训练完成的模型部署到实际应用中。

KNN算法

模型简介

K 近邻算法(K-Nearest Neighbors,简称 KNN)是机器学习中最基础和直观的算法之一,可以用于分类或者回归。
K 近邻算法属于监督学习的一种,核心思想是通过计算待分类样本与训练集中各个样本的距离,找到距离最近的K 个样本,然后根据这 K 个样本的类别或值来预测待分类样本的类别或值。简单来说就是“物以类聚”。

基本工作流程

  1. 数据收集:与其他机器学习算法一致,数据收集是必备的第一步。
  2. 数据预处理: 在 KNN 算法中,对数据进行归一化进行处理是非常重要的,因为 KNN 算法的核心是计算距离,如果某一特征的量级远大于其他特征,会导致该特征对结果的影响非常显著,因此在模型训练前,需要对数据进行归一化处理,确保每个特征对距离的贡献是相同的。
  3. 模型训练:KNN 是一种惰性学习的算法,没有显式的训练阶段,实际的训练只是把训练集的数据存起来,真正的计算发生在预测阶段。
  4. 预测:在预测过程中,模型会计算输入的样本与训练集中每一个样本之间的距离(这里的距离有多种计算方式),将所有计算出的距离进行升序排序,再从中选取前 K 个样本,进行决策(根据分类还是回归有不同的决策方式)。

算法的优缺点

  • 优点
    • 简单直观:KNN 算法原理简单,符合直觉,易于理解。
    • 无需训练:KNN 是一种惰性学习的算法,没有显式的训练阶段。
    • 对数据分布无要求:KNN 不对数据的分布做任何假设,适用于各种类型的数据。
  • 缺点
    • 预测速度慢:KNN 算法每预测一个新样本,都需要和训练集中的全部样本都计算一次距离,计算复杂度高。
    • 对样本不平衡敏感:如果某个类别的样本数量特别多,它在 K 个邻居中占多数的概率就大,容易主导预测结果。
    • 维度灾难:当特征维度非常高(比如几千维)时,所有样本之间的距离都会变得差不多,导致 KNN 失效。

补充

  1. 距离:判断两个样本到底有多近需要计算两个样本之间的距离,常用的距离有:
  • 欧氏距离 (Euclidean Distance):最常用,即两点之间的直线距离。在二维平面上可以使用勾股定理进行计算,并由此可以推广到高维空间。
  • 曼哈顿距离 (Manhattan Distance):街区距离,就像在曼哈顿街道开车,只能沿着网格线走(绝对值距离之和)。在二维平面上表示为两点之间横坐标差值的绝对值加上纵坐标差值的绝对值。


2.决策方式:根据分类任务还是回归任务选择不同的决策方式:

  • 如果是分类任务:采用多数表决原则。这 K 个邻居中哪个类别最多,新样本就归为哪一类。
  • 如果是回归任务:采用平均值原则。将这 K 个邻居的目标数值求平均,作为新样本的预测值。

基于KNN算法实现手写数字识别

数据准备

本项目数据集选用MNIST数据集,该数据集于 1998 年由 Yann LeCun 等人发布,可以说是机器学习领域中的"Hello World"。数据集中包括 70000 张 28 × 28 像素 (单通道灰度图,像素值范围为0 - 255,0 表示纯黑,255 表示纯白)的图像,可以分为 60000 张训练集与 10000 张测试集。
首先获取数据集并保存到本地。

fromsklearn.datasetsimportfetch_openmlfromsklearn.model_selectionimporttrain_test_splitfromsklearn.neighborsimportKNeighborsClassifierfromsklearn.metricsimportaccuracy_scoreimportjoblibimportpandasaspdimportnumpyasnpimportmatplotlib.pyplotasplt# 获取数据,使用sklearn提供的fetch_openml方法# 下载 MNIST 数据集# (784个特征代表28x28的像素,version=1表示数据集版本号,as_frame=False表示返回的是NumPy数组)mnist=fetch_openml('mnist_784',version=1,as_frame=False,parser='auto')# 提取特征集与标签集X,y=mnist.data,mnist.target# X 形状为 (70000, 784),y 形状为 (70000,)# 保存为 NumPy 的 .npz 格式是速度最快、体积最小的选择。它可以将 X 和 y 打包压缩进一个文件中。np.savez_compressed('mnist.npz',X=X,y=y)

可以查看数据集中的前十张照片。

fig,axes=plt.subplots(2,5,figsize=(10,5))fig.suptitle("MNIST数据集的前十张示例图片",fontsize=16)fori,axinenumerate(axes.flat):# 将第i张图片(X[i])reshape成28×28的矩阵image=X[i].reshape(28,28)ax.imshow(image,cmap='gray')# 将图片对应的标签显示为标题ax.set_title(y[i])ax.axis('off')# 隐藏坐标轴# 绘制图片plt.tight_layout()plt.show()

运行结果如下:

数据预处理与模型训练

因为MNIST数据集的特征处理较为简单,所以数据处理与模型训练合并在一起。

# 定义训练并保存模型函数deftrain_model(x_data,y_target):# 数据归一化,因为数据范围在0-255之内,只需要归一化到0-1之间,与数据的分布无关,因此可以直接对整个数据集进行归一化x_data=x_data/255.0# 切分训练集与测试集X_train,X_test,y_train,y_test=train_test_split(x_data,y_target,test_size=10000,train_size=60000,random_state=6,stratify=y_target)# 创建KNN分类器knn_estimator=KNeighborsClassifier(n_neighbors=5)# 训练模型knn_estimator.fit(X_train,y_train)# 评估模型print("准确度: ",knn_estimator.score(X_test,y_test))# 保存模型joblib.dump(knn_estimator,'./my_model/knn_model.pkl')print("训练完成,模型已经保存")# 读取数据并训练data=np.load('mnist.npz',allow_pickle=True)X=data['X']y=data['y']train_model(X,y)

运行结果如下:

准确度: 0.9725 训练完成,模型已经保存

预测数据

使用已经训练好的模型对图片进行预测

# 预测数据defpredict_data(path):# 读取模型knn_estimator=joblib.load('./my_model/knn_model.pkl')# 读取图片img=plt.imread(path)# 显示图片plt.imshow(img,cmap='gray')plt.axis('off')plt.show()# 预测图片x=img.reshape(1,-1)#如果图片格式为PNG,则返回0-1之间的数值,如果图片格式为其他格式,则返回0-255之间的数值,需要归一化。pred=knn_estimator.predict(x)returnpredprint("预测结果为",predict_data('./KNN_MNIST_TEST/demo.png'))

运行结果如下:

预测结果为 ['2']

PCA降维与超参数优化

看到这里,大家肯定会有疑问:什么是PCA?什么是超参数优化?

  • PCA
    PCA(Principal Component Analysis,主成分分析)是统计学和机器学习中最常用的降维算法,可以用于解决KNN的维度灾难以及计算开销问题。PCA 通过正交变换,将原始高维数据投影到方差最大的少数主成分上,实现降维的同时保留数据最主要的变异信息。简单来说,PCA 就是通过某种算法,把原始数据中众多相关指标浓缩成少数几个互不相干的核心指标,在牺牲少量精度的前提下,大幅降低数据复杂度。
    以MNIST数据集为例,MNIST每张图像有28×28=784维,若训练集有60000个样本,每次预测都要计算60000次784维向量的距离,非常耗时。并且通过前面的实操大家可以观察到,MNIST图像有很多像素的数据是重复且多余的,比如图片的边缘部分有大量的黑色像素,这些黑色像素并没有提供与数字预测相关的信息。
  • 超参数优化
    首先说明一下什么是超参数,在机器学习中,参数可以分为两类:模型参数(Parameters) 和超参数(Hyperparameters)。模型参数是模型内部通过训练数据自动学习到的变量,例如线性回归的权重。而超参数是在模型训练开始前,人为设定的外部参数,例如 KNN算法里的K值。
    超参数优化(Hyperparameter Tuning) 的目的,就是寻找一组最佳的超参数组合,使得模型在测试集上表现良好,从而避免过拟合或欠拟合。
    寻找最佳超参数的标准方法是结合 交叉验证(Cross-Validation) 和 网格搜索(Grid Search)。
    • 交叉验证:将训练集分成 N 份(如 5 折),轮流用 4 份训练,1 份验证,取平均成绩。这能有效防止模型在特定的验证集上过拟合。
    • 网格搜索:预先设定好各个超参数的候选值列表,算法会尝试每个超参数,并通过交叉验证找出得分最高的那一组。

由于相关内容较多,我会单独写一个帖子来实现。

问题与总结

在本项目测试过程中,最开始我使用了自己在ps中绘制的数字图片,但是预测效果非常差,我使用cursor进行问题排查,cursor得出的结论如下:
MNIST 测试集上的高准确率并不代表模型在任意手写图片上都能有同样表现。测试集与训练集同源、分布一致,而自定义 PNG 在写法风格、笔画粗细和数字位置上与 MNIST 存在明显偏差,其中最关键的一点是:MNIST 数字已做居中处理,若预测时仅将图片 reshape 后直接输入,像素位置偏移会显著拉大与训练样本的距离。KNN 依赖像素级欧氏距离找最近邻,对位置和分布变化极为敏感,不学习抽象特征,因此在“分布外”数据上容易失效。因此,高测试准确率更多反映模型在标准数据集上的拟合效果,真实预测效果还取决于输入与训练数据的一致性,以及 KNN 算法本身的局限。
因此使用 KNN 算法本身并不是很适合用于实现通用的手写数字识别,本项目仅仅是作为机器学习入门,帮助大家熟悉机器学习的步骤,掌握数据的处理,了解 KNN 算法的原理以及实现过程。
欢迎大家在评论区交流学习心得,也欢迎大佬指出文章中的错误。
下一篇我会分享决策树算法并基于该算法实现手写数字识别。

http://www.cnnetsun.cn/news/3032329.html

相关文章:

  • SketchUp STL插件终极指南:免费快速实现3D打印的完整解决方案
  • AI中转平台选型:上手前值得确认的10个问题
  • 计算机毕业设计之超市会员积分管理系统
  • Slack 集成 Claude Tag 实操指南:四步配置流程与 ambient 模式详解
  • 三步掌握XHS-Downloader:从小红书内容收集到专业素材库的完整路径
  • 工装装修哪家好?广东工装优选创雅(广东)数码科技有限公司
  • 【计算机毕业设计案例】基于 Spring Boot 的高校教务请假管理系统的设计与实现 基于 Web 技术的学生线上请假审批管理系统的设计与实现(程序+文档+讲解+定制)
  • 呼市装修避坑指南,深耕本地 10 年的玉虎装饰,凭六大优势打动无数业主
  • AI合同管理“越用越懂你”,到底懂什么、怎么懂?
  • BloodHound:用图论挖出 Active Directory 里隐藏的攻击路径
  • 低预算车场方案:解析西安富平图科适用场景
  • GTA5线上小助手:终极免费工具完全指南 - 解锁洛圣都无限可能
  • Java毕设选题推荐:基于 B/S 架构的西点甜点线上商城系统的设计与实现 基于 Spring Boot 的烘焙食品线上售卖平台的设计与实现【附源码、mysql、文档、调试+代码讲解+全bao等】
  • CODESYS 国产紧凑型 PLC 选型与实操指南:Bronze100 系列硬件、软件、现场落地全解析
  • CAXA电子图版2023 详细图文安装教程(附安装包)CAXA电子图版安装教程
  • 计算机Java毕设实战-基于 SpringBoot+MVC 架构的教务综合管理系统的设计与实现 前后端分离模式下高校教务管理系统【完整源码+LW+部署说明+演示视频,全bao一条龙等】
  • 别等 Agent 上线后补评估:先用 DeepEval 写失败样本
  • 2-LangGraph-Graph核心API-图和状态
  • 微信数据解放:三步掌握你的聊天记录解密技巧
  • 计算机毕业设计之jsp基于Web的有机蔬菜销售网站的设计与实现
  • 067、自定义插件开发:API 接口设计、权限声明与发布流程
  • 终极指南:微信聊天记录解密与数据恢复的专业方案
  • Joy-Con Toolkit终极指南:如何解锁任天堂手柄的隐藏潜能
  • 【TEE从入门到精通及实战】61 梯度中毒防御:在SGX enclave中实现鲁棒聚合
  • 彻底解决显卡驱动冲突:DDU深度清理工具完全指南
  • 计算机毕业设计之基于微信小程序的宠物领养系统
  • Ctrl+Alt+Shift+V都用错了?IDEA快捷键认知盲区大起底,92%开发者漏掉这5个核心组合键
  • 从AI4S跨越至AI4E,工程教育的“算力底座”终于补齐!
  • openHAB Core:智能家居的底层框架,不卖产品只卖能力
  • 性能测试三剑客:JMeter、Locust 与 k6 的全面对比与选型指南