手写体识别,是指计算机从纸张、照片、触摸屏或其他设备中接收并识别人手写的文字等信息的技术。它在我们的生活中有着广泛的应用,例如文档处理、移动设备输入、个性化签名、教育和辅助技术等等。
本教程提供对您提供的 PyTorch 代码的详细解释,逐步指导您完成使用全连接神经网络 训练模型以识别 MNIST 数据集中的手写数字。
1. 数据集下载预处理
mnist手写体数据集下载,存放到./dataset文件夹下
需要将数据集分训练集和测试集
为什么不将所有数据拿来训练?
在机器学习中,通常会将数据集划分成训练集和测试集。虽然直觉上可能认为应该使用所有数据进行训练,但将数据划分成不同的子集对于有效地开发模型至关重要,以下是一些原因:
1. 评估模型泛化能力:
测试集的主要目的是评估训练模型的泛化能力。泛化能力是指模型在新数据上表现良好的能力,这对于实际应用至关重要。
如果使用整个数据集进行训练,包括用于评估的数据,模型可能会简单地记住训练示例而无法泛化到新数据。这种现象称为过拟合。
通过使用单独的测试集,您可以评估模型在从未见过的数据上的表现如何,从而提供更真实的性能估计。
2. 防止过拟合:
过拟合是指模型过于针对特定的训练数据,捕获噪声和无关模式,而不是学习数据中的底层关系。
将数据分成训练集和测试集有助于通过引入验证集来防止过拟合。验证集用于训练过程中监控模型在未见过的数据上的性能。
当模型开始过拟合训练数据时,其在验证集上的性能通常会开始下降。这可以作为停止训练并防止进一步过拟合的早期预警信号。
3. 提高模型选择效率:
在训练过程中,您可能会尝试不同的超参数、模型架构或训练技巧。
使用单独的测试集可以让您在相同的数据上评估不同模型或训练配置的性能,从而客观地选择最适合任务的模型。
4. 保留数据用于未来评估:
在某些情况下,您可能需要保留部分数据用于未来的评估,例如随着时间的推移比较不同的模型或技术。
通过保留一个隔离的测试集,您可以确保即使总体数据集随着时间的推移增长或变化,您也始终拥有用于未来比较的一致且无偏倚的数据。
#训练集和验证集
training_set_full = datasets.MNIST('dataset/', train=True, transform=transforms.ToTensor(), download=True)
#测试集
test_set = datasets.MNIST('dataset/', train=False, transform=transforms.ToTensor(), download=True)
import torch
import torchvision
from torchvision import transforms
# 定义数据预处理操作
transform = transforms.Compose([
transforms.ToTensor(), # 将图像转换为张量
transforms.Normalize((0.1307,), (0.3008,)) # 归一化
])
# 加载 MNIST 数据集
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# 将训练集划分为训练集和验证集
train_size = len(train_dataset)
val_size = int(0.1 * train_size)
train_idx = list(range(train_size))
val_idx = list(range(train_size - val_size, train_size))
random.shuffle(train_idx)
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_idx[:-val_size])
val_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_idx[-val_size:])
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, sampler=train_sampler)
val_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, sampler=val_sampler)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64)
数据集的可视化
SAMPLE_IMG_ID = np.random.choice(len(training_set))
junk = plt.imshow(training_set[SAMPLE_IMG_ID][0].squeeze(0), cmap='gray') # "squeeze" removes the first dimension (1,28,28) => (28,28)
junk = plt.title(training_set[SAMPLE_IMG_ID][1])
2. 模型定义
使用类定义一个模型,在初始化里面定义好各个层的定义为一些函数,模型的层与层之间的参数传递实际上就是一堆函数的嵌套。
整个模型相当于一个很多层的嵌套函数,但是函数里面有可调整的参数,这些会影响神经网络的输出结果,训练的过程就是调整这些参数,使得神经网络的输出可以很好地近似到真实的数据值。比如这个案例中是一个十分类问题,神经网络输入是图像的像素点集合本例中是28*28,输出的是10个概率值,表示预测的label的概率,其和为1.
import torch.nn as nn
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2)
)
self.conv2 = nn.Sequential(
nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2)
)
self.fc1 = nn.Linear(14 * 14 * 64, 1000)
self.fc2 = nn.Linear(1000, 10)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size(0), -1)
x = self.fc1(x)
x = self.fc2(x)
return x
3. 模型训练
定义一些超参数,例如 EPOCHS
指定训练的轮数,EVALUATION_FREQ
指定每隔多少个批次进行一次模型在验证集上的评估
外部循环用于迭代指定的训练轮数 (EPOCHS
)。内部循环将在每个 epoch 中处理训练集的批次数据。
# 循环遍历训练周期 (epoch)
for epoch in range(EPOCHS):
print(f'第 {epoch + 1} 个 epoch')
epoch_acc = [] # 保存每个 epoch 的准确率
training_acc_checkpoint, training_loss_checkpoint = [], [] # 临时保存训练过程中的精度和损失,用于计算平均值
# 遍历训练数据集中的每一个批次
for batch_idx, (data, labels) in enumerate(training_loader):
# 将数据和标签移动到指定设备 (CPU 或 GPU)
data, labels = data.to(device), labels.to(device)
# 评估模型,获得预测结果、准确率和损失值
predictions, acc, loss = evaluate(model, loss_function, data, labels)
training_acc_checkpoint.append(acc)
epoch_acc.append(acc)
training_loss_checkpoint.append(loss.item())
# 反向传播计算梯度
loss.backward()
# 更新模型参数
optimizer.step()
# 清空梯度 (与 optimizer 相关)
optimizer.zero_grad() # 或者 model.zero_grad() (如果所有模型参数都在优化器中)
# 周期性评估验证集
if batch_idx % EVALUATION_FREQ == 0:
# 计算并保存平均训练精度和损失
training_acc_lst.append(np.mean(training_acc_checkpoint))
training_loss_lst.append(np.mean(training_loss_checkpoint))
# 清空临时保存的训练过程数据
training_acc_checkpoint, training_loss_checkpoint = [], []
# 评估验证集 (进入评估模式并关闭梯度追踪)
model.train(mode=False) # 进入评估模式 (参考链接: https://stackoverflow.com/a/55627781/900394)
with torch.no_grad(): # 临时关闭梯度追踪
validation_acc_checkpoint, validation_loss_checkpoint = [], []
validation_predictions = [] # 保存用于之后展示结果的预测值
for val_batch_idx, (val_data, val_labels) in enumerate(validation_loader):
val_data, val_labels = val_data.to(device), val_labels.to(device)
# 评估单个验证批次
val_predictions, validation_acc, validation_loss = evaluate(model, loss_function, val_data, val_labels)
validation_loss_checkpoint.append(validation_loss.item())
validation_acc_checkpoint.append(validation_acc)
validation_predictions.extend(val_predictions) # 扩展 (append 会覆盖) 所有验证预测值
# 计算并保存平均验证精度和损失
validation_acc_lst.append(np.mean(validation_acc_checkpoint))
validation_loss_lst.append(np.mean(validation_loss_checkpoint))
# 重新进入训练模式
model.train(mode=True)
# 打印当前 epoch 的训练和验证结果
print(f'训练精度: {training_acc_lst[-1]:.2f}, 训练损失: {training_loss_lst[-1]:.2f}, 验证精度: {validation_acc_lst[-1]:.2f}, 验证损失: {validation_loss_lst[-1]:.2f}')
4. 模型保存
模型保存是将训练好的模型参数和结构存储到文件中的过程,以便以后加载和使用。这对于以下目的至关重要:
#保存模型
torch.save(model.state_dict(), 'my_model.pth')
5. 模型转换
什么是 ONNX?
ONNX(Open Neural Network Exchange,开放神经网络交换)是一种用于表示深度学习模型的开放格式。它允许将模型存储在文件中,并在不同框架(例如 PyTorch、TensorFlow、MXNet 等)之间共享和转换。
ONNX 的目标是简化深度学习模型的部署和共享。它使研究人员和开发人员能够专注于构建模型,而无需担心如何将其部署到特定平台或框架。
ONNX 的优点
使用 ONNX 有很多优点,包括:
所以我们现在基本了解了onnx是一个开放格式,很多框架的模型都可以转换到onnx,或者由onnx转出,转换成onnx可以移植
torch提供了模型转换的方法
其中维度必须要和模型的输入一致,我这里自定义的模型输入为28*28
#导出为onnx模型
dummy_input = torch.randn(1, 1, 784)
torch.onnx.export(model, dummy_input, "my_model.onnx", verbose=False)
使用netron软件查看onnx神经网络结构
netron是一个深度学习模型可视化库,其支持以下格式的模型存储文件:
netron并不支持pytorch通过torch.save方法导出的模型文件,因此在pytorch保存模型的时候,需要将其导出为onnx格式的模型文件,可以利用torch.onnx模块实现这一目标。
对于分析网络结构,以及各个层的参数量,整个可视化效果很好
netron有网页在线版和软件版本,直接拖拽模型加载即能可视化
6. 模型训练与验证结果
由于我们并没有使用卷积神经网络或者循环神经网络,模型的最高准确度也只能达到97%,一些使用了卷积层的神经网络拟合能力更加强大,可以达到99%的准确度
模型文件onnx
模型文件pytorch
模型转换
模型与训练文件