python神经网络编程入门(一)—— 分类器

发布时间:2026/7/5 3:27:27
python神经网络编程入门(一)—— 分类器 写这篇专栏一来是为了帮自己把一些基础知识梳理得更扎实二来也是因为现在网上“调包侠”太多了——我不反对用现成的库快速解决业务问题毕竟实用主义没毛病但很多内容打着“AI”的旗号却只教人“from sklearn import xxx”看完还是一头雾水。所以我想换一种方式从最朴素的数学出发把原理掰开揉碎讲清楚。当然我自己的理解也未必完美欢迎各位讨论和指正。这个系列不追求高大上的框架我们只关心一件事机器到底是怎么“学”的。今天就从最简单的分类器开始连“神经网络”都算不上就是一条直线。问题给两类虫子找条分界线假设我们收集了两类昆虫的数据A类和B类每个虫子我们量了它的体宽x和体长y数据点大概长这样图1。肉眼一看这两类虫子好像可以用一条直线分开——A类在线的上方B类在线的下方。图1所以我们想找到这条直线的方程。为了简单我们假定直线过原点形式为为什么不加截距 b因为我们想用一条从原点出发的射线来分界如果 b≠0直线不过原点在这个具体问题里没有实际意义反而增加一个参数不利于初学者理解。况且数据本身的分布也支持我们这样做你回头看图1就明白了。我们的任务就是从已知数据中“学”出这个斜率 k使得直线能把两类虫子分开。第一次尝试随便猜个 k假如我们先拍脑袋令 k0.1画出来的直线如图2虚线。显然这条线太“躺”了几乎所有点都在它上面根本分不开。但我们不能靠“肉眼”去调得有一套机械化的步骤也就是算法。图2从错误中学习误差怎么用我们任选一个样本图3比如 B 类中宽度 x5长度 y1 的那个点。用当前的 k0.1 去计算预测长度真实长度是 1显然我们算得太小了误差为但注意我们并不希望直线正好穿过这个 B 类点因为这样它就在分界线上了而不是在线的下方。我们想要所有 B 类点严格在直线下方所以对 B 类点我们给它一个目标值比真实值稍大一点比如 1.1。这样只要直线经过 (5, 1.1)那真实点 (5, 1) 就在直线下方了。所以真正的误差同样的如果处理 A 类点我们会给一个比真实值稍小的目标值比如真实 y4目标取 3.9确保 A 类点在直线上方。图3关键一步误差如何变成 k 的调整量现在我们知道误差是 E我们希望调整 k 来减小这个误差。假设新的斜率为 kΔk那么新的预测值为我们希望 ynew恰好等于我们设定的目标值 t所以而当前预测 kx 与目标值的差就是误差 Et−kx因此于是我们得到更新公式也就是说误差越大、样本宽度越小我们就需要调整得越多。连续更新两个样本看看会发生什么让我们从 k0.1开始先处理 B 类样本 (5, 1)目标 t1.1计算得所以 k 更新为 0.22。直线如图4中的绿色虚线。接着我们再处理一个 A 类样本 (1, 4)目标 t3.9比真实小一点当前 k0.22k 一下子跳到 3.90这条直线图4红色实线几乎完美地经过了 A 类样本 (1,4)但此时 B 类样本全被压在直线下方分类效果完全被破坏了。图4问题出在哪因为每次更新都用当前的样本“完全修正”误差导致斜率朝新样本的方向过度偏转完全忘记了之前样本的信息。这就是我们常说的“过拟合”或“遗忘”。我们打个比方。假设你是一个正在学习做题的学生。第一步老师先给你看了一道“B类”的例题样本 5,1。你根据这道题琢磨了一下把做题的经验值斜率调整到了0.22。这时候你觉得“嗯有点感觉了”。第二步翻车现场紧接着老师又给你看了一道“A类”的新题样本 1,4。为了让你完美答对这道新题你直接把刚才总结的 0.22 全扔了掏出橡皮擦把经验值狠狠改成了3.90结果这下好了你确实能把“A类”这道新题解得完美无缺。可是回过头来老师再让你做“B类”那道题你发现自己完全不会了因为你的经验值3.90已经偏离B类题十万八千里了。我们的算法“记性太差”像个墙头草——完全被最后看到的那个样本“牵着鼻子走”。它以为只要让直线完美贴合最后一个点就算学好了结果直接把之前学会的所有规律都覆盖掉了。这就像考试前只背了最后一道题的答案结果试卷一出傻眼了。在机器学习里这种“刚学了新东西就把旧东西忘光光”的现象叫做“灾难性遗忘”也就是过拟合的一种极端表现。这种“学一个丢一个”的学习方法肯定是不行的。解决方案别太激进慢慢来我们想要的是每次只朝误差方向迈一小步而不是一步到位。给更新公式加一个“学习率” α0 α 1α 控制步长比如取 0.01。这样每次只移动一点点然后在所有样本上反复多次迭代斜率就会缓慢地朝着一个兼顾所有样本的方向移动。用这个策略我们跑 200 轮每轮遍历全部样本最终 k 会稳定在某个值附近比如 k≈1.19。画出来图5可以看到 A 类都在线上方B 类都在线下方分类效果很好。这里200轮的含义是将全部20个样本从头到尾完整遍历一遍算作1轮总共这样重复遍历200次。之所以不能只跑1轮就结束是因为引入学习率如0.01后每次更新的步长极其微小仅遍历一遍20次更新只能让斜率k从初始值0.1挪动到0.15左右离真正的最优值0.79还差得很远我们必须依靠成百上千次的重复遍历通过“蚂蚁搬家”般的微小步长累积足够多的调整量才能让参数平稳地收敛到最优位置而不是像激进更新那样一步跨过头导致彻底崩盘。图5结论我们从一条简单的直线 ykx出发通过计算单个样本的误差推导出斜率调整公式 ΔkE/x。但若每次完全修正模型会“忘本”。引入学习率后每次仅微调反复迭代最终模型能学到一个泛化能力较好的分界线。这个过程中我们其实已经触碰到了机器学习最核心的思想——梯度下降只不过这里导数很简单。后续我们会看到即使复杂的神经网络本质上也是在做类似的事情度量误差、反向传播、小幅更新参数。希望这篇能帮你理解“机器学习”到底在做什么。大家也可以自己尝试看看当数据不能用一条过原点的直线分开时该怎么办。下一期我将简单说明下并开始介绍神经网络。代码片段1计算初始斜率k00.1用 B 类样本 (5,1) 更新得到k10.22再用 A 类样本 (1,4) 更新得到k23.90# aggressive_update.py # 运行此文件可生成展示“激进更新”现象的图示无学习率完全修正 import matplotlib matplotlib.use(Agg) # 使用非GUI后端避免弹窗适合服务器或脚本运行 import matplotlib.pyplot as plt import numpy as np # 原始数据 # 数据格式(编号, 类别, 宽度, 长度) # 类别 A 和 B 分别代表两种昆虫我们需要找到一条直线 ykx 将它们分开 data [ (1, A, 1.0, 4.0), (2, A, 1.5, 3.5), (3, A, 2.0, 4.5), (4, A, 2.5, 3.0), (5, A, 1.0, 3.5), (6, A, 1.5, 4.0), (7, A, 2.0, 3.0), (8, A, 2.5, 4.0), (9, A, 1.5, 4.5), (10, A, 2.0, 3.5), (11, B, 3.0, 1.0), (12, B, 3.5, 1.5), (13, B, 4.0, 2.0), (14, B, 4.5, 2.5), (15, B, 5.0, 1.0), (16, B, 3.0, 2.0), (17, B, 3.5, 2.5), (18, B, 4.0, 1.5), (19, B, 4.5, 1.0), (20, B, 5.0, 2.0) ] # 数据分离 # 将 A 类和 B 类的宽度和长度分别存入不同的列表方便后续绘图 A_width, A_length [], [] # A 类的宽度和长度列表 B_width, B_length [], [] # B 类的宽度和长度列表 for _, cls, w, l in data: # 遍历每个样本忽略编号 if cls A: A_width.append(w) # 将宽度添加到 A_width A_length.append(l) # 将长度添加到 A_length else: # 类别为 B B_width.append(w) B_length.append(l) # 激进更新无学习率 # 我们假设初始斜率为 0.1即直线 y 0.1x k0 0.1 # 第一步使用 B 类样本 (宽度5.0, 长度1.0) 进行更新 # 为了确保 B 类点位于直线下方我们设定一个“目标值” target_b 1.1比真实值 1.0 略高 target_b 1.1 # 计算当前误差目标值 - 当前预测值 (k0 * 5.0) E_b target_b - k0 * 5.0 # 根据推导的公式 Δk E / x更新斜率 k1 k0 E_b / 5.0 # 得到 k1 0.22 # 第二步使用 A 类样本 (宽度1.0, 长度4.0) 进行更新 # 为了确保 A 类点位于直线上方设定目标值 target_a 3.9比真实值 4.0 略低 target_a 3.9 # 计算误差目标值 - 当前预测值 (k1 * 1.0) E_a target_a - k1 * 1.0 # 再次完全修正斜率 k2 k1 E_a / 1.0 # 得到 k2 3.90极大 # 在控制台输出三个斜率值便于观察变化 print(fInitial k {k0:.3f}) # 初始 0.100 print(fAfter B update k {k1:.3f}) # 更新 B 后 0.220 print(fAfter A update k {k2:.3f}) # 更新 A 后 3.900 # 绘图 # 创建一个图形窗口figure尺寸为 7x6 英寸分辨率 150 dpi fig, ax plt.subplots(figsize(7, 6), dpi150) # 准备 x 轴的数据只需两个点即可确定一条直线0 和 6 # 因为直线 y kx 过原点所以 x0 时 y0x6 时 y6k x_line np.linspace(0, 6, 2) # [0, 6] # ---- 绘制数据散点图 ---- # A 类红色 (#E63946)圆形大小 60半透明 ax.scatter(A_width, A_length, c#E63946, labelClass A, s60, alpha0.8) # B 类蓝色 (#457B9D)圆形大小 60半透明 ax.scatter(B_width, B_length, c#457B9D, labelClass B, s60, alpha0.8) # ---- 绘制三条不同斜率的直线 ---- # 1. 初始直线 y k0 * x点状线 (:)橙色 (#F77F00) ax.plot(x_line, k0 * x_line, c#F77F00, linestyle:, linewidth2, labelfInitial k{k0:.2f}) # 2. 更新 B 后的直线 y k1 * x虚线 (--)青色 (#2A9D8F) ax.plot(x_line, k1 * x_line, c#2A9D8F, linestyle--, linewidth2, labelfAfter B k{k1:.2f}) # 3. 更新 A 后的直线 y k2 * x实线 (-)红色 (#E63946) ax.plot(x_line, k2 * x_line, c#E63946, linestyle-, linewidth2, labelfAfter A k{k2:.2f}) # ---- 突出显示用于更新的两个样本点 ---- # B 样本 (5, 1) 用黑色方块标记大小 100 ax.scatter([5.0], [1.0], cblack, markers, s100, labelB sample (5,1)) # A 样本 (1, 4) 用黑色菱形标记大小 100 ax.scatter([1.0], [4.0], cblack, markerD, s100, labelA sample (1,4)) # ---- 设置坐标轴标签、标题、图例、网格等 ---- ax.set_xlabel(Width) # x 轴标签宽度 ax.set_ylabel(Length) # y 轴标签长度 ax.set_title(Aggressive updates (no learning rate) – forgetting) # 图表标题 ax.legend() # 显示图例 ax.grid(linestyle--, alpha0.3) # 显示网格线虚线透明度 0.3 ax.set_xlim(0, 6) # x 轴范围 0~6 ax.set_ylim(0, 7) # y 轴范围 0~7 # 调整布局防止标签被裁剪然后保存图片到当前目录 plt.tight_layout() plt.savefig(aggressive_update.png, bbox_inchestight) plt.close(fig) # 关闭图形释放内存 # 打印提示信息 print(Image saved as: aggressive_update.png)代码片段2使用相同的初始斜率k0.1用带学习率alpha0.01的方式在所有样本上迭代 200 轮得到最终斜率k_final绘制初始直线和最终直线直观展示分类效果A 在线上方B 在线下方# moderate_update.py # 运行此文件可生成展示“适度更新”最终直线的图示带学习率 import matplotlib matplotlib.use(Agg) # 使用非GUI后端避免弹窗适合服务器或脚本运行 import matplotlib.pyplot as plt import numpy as np # 原始数据 # 数据格式(编号, 类别, 宽度, 长度) # 类别 A 和 B 分别代表两种昆虫目标是找到一条直线 ykx 将它们分开 data [ (1, A, 1.0, 4.0), (2, A, 1.5, 3.5), (3, A, 2.0, 4.5), (4, A, 2.5, 3.0), (5, A, 1.0, 3.5), (6, A, 1.5, 4.0), (7, A, 2.0, 3.0), (8, A, 2.5, 4.0), (9, A, 1.5, 4.5), (10, A, 2.0, 3.5), (11, B, 3.0, 1.0), (12, B, 3.5, 1.5), (13, B, 4.0, 2.0), (14, B, 4.5, 2.5), (15, B, 5.0, 1.0), (16, B, 3.0, 2.0), (17, B, 3.5, 2.5), (18, B, 4.0, 1.5), (19, B, 4.5, 1.0), (20, B, 5.0, 2.0) ] # 数据分离 # 将 A 类和 B 类的宽度和长度分别存入独立的列表方便后续绘图 A_width, A_length [], [] # 存储 A 类的宽度和长度 B_width, B_length [], [] # 存储 B 类的宽度和长度 for _, cls, w, l in data: # 遍历每个样本忽略编号 if cls A: A_width.append(w) # 将宽度加入 A_width A_length.append(l) # 将长度加入 A_length else: # 类别为 B B_width.append(w) B_length.append(l) # 带学习率的训练函数 def train_with_lr(alpha0.01, epochs200, margin1.0): 使用在线梯度下降逐样本更新训练斜率 k。 参数 alpha : 学习率控制每次更新的步长默认 0.01 epochs : 训练轮数即完整遍历所有样本的次数默认 200 margin : 边界余量用于设定目标值保证分类间隔 返回 训练得到的最终斜率 k # 预先构建目标值列表每个样本对应一个目标值 # 对于 A 类目标值 真实长度 - margin略低于真实值使得真实点在直线上方 # 对于 B 类目标值 真实长度 margin略高于真实值使得真实点在直线下方 targets [] for _, cls, w, l in data: if cls A: targets.append((w, l - margin)) # A 类目标值稍低一点 else: targets.append((w, l margin)) # B 类目标值稍高一点 k 0.1 # 初始斜率同激进更新一样从 0.1 开始 # 外层循环遍历 epochs 轮 for _ in range(epochs): # 内层循环按顺序遍历所有样本顺序可以不变也可以打乱这里保持原样 for x, target in targets: pred k * x # 当前斜率下的预测值 error target - pred # 计算误差目标值 - 预测值 k alpha * error / x # 根据公式 Δk α * (error / x) 更新斜率 return k # 训练与验证 # 设定初始斜率 k0 用于绘图对比 k0 0.1 # 调用训练函数使用学习率 0.01训练 200 轮边界余量 1.0 k_final train_with_lr(alpha0.01, epochs200, margin1.0) # 输出最终斜率显示 6 位小数以便观察细微变化 print(fFinal k {k_final:.6f}) # 验证最终直线的分类效果统计错误分类的样本数 errors 0 for _, cls, w, l in data: pred k_final * w # 计算该样本在最终直线下的预测值 if cls A and l pred: # A 类应位于直线上方实际长度 预测值 errors 1 elif cls B and l pred: # B 类应位于直线下方实际长度 预测值 errors 1 print(fClassification errors: {errors} / {len(data)}) # 输出错误数/总数 # 绘图 # 创建画布尺寸 7x6 英寸分辨率 150 dpi fig, ax plt.subplots(figsize(7, 6), dpi150) # x 轴数据只需两个点0 和 6即可绘制直线因为 ykx 过原点 x_line np.linspace(0, 6, 2) # [0, 6] # ---- 绘制散点图 ---- # A 类红色 (#E63946)圆形大小 60半透明 ax.scatter(A_width, A_length, c#E63946, labelClass A, s60, alpha0.8) # B 类蓝色 (#457B9D)圆形大小 60半透明 ax.scatter(B_width, B_length, c#457B9D, labelClass B, s60, alpha0.8) # ---- 绘制直线 ---- # 初始直线 y k0 * x点状线 (:)橙色 (#F77F00)用于对比 ax.plot(x_line, k0 * x_line, c#F77F00, linestyle:, linewidth2, labelfInitial k{k0:.2f}) # 最终直线 y k_final * x实线青色 (#2A9D8F)较粗突出显示 ax.plot(x_line, k_final * x_line, c#2A9D8F, linewidth2, labelfFinal k{k_final:.3f}) # ---- 设置图表元素 ---- ax.set_xlabel(Width) # x 轴标签宽度 ax.set_ylabel(Length) # y 轴标签长度 ax.set_title(Moderate updates with learning rate α0.01) # 图表标题 ax.legend() # 显示图例 ax.grid(linestyle--, alpha0.3) # 网格线虚线透明度 0.3 ax.set_xlim(0, 6) # x 轴范围 0~6 ax.set_ylim(0, 7) # y 轴范围 0~7 # 调整布局保存图片关闭图形释放内存 plt.tight_layout() plt.savefig(moderate_update.png, bbox_inchestight) plt.close(fig) # 控制台提示图片已保存 print(Image saved as: moderate_update.png)