#AI挑战营第一站# MNIST手写数字识别模型训练及优化
[背景介绍]
在人工智能的浪潮中,边缘计算作为推动AI技术落地的重要力量,其在嵌入式系统中的应用日益广泛。在这个项目中,我们致力于开发一个轻量级的手写数字识别模型,以便在资源受限的嵌入式设备上运行。我们将使用经典的MNIST数据集,这是一个包含手写数字图像的大型数据集,每个图像都标有相应的数字标签。我们的目标是训练一个模型,能够准确地识别这些手写数字,并将其应用于嵌入式设备,如RV1106开发板。由于嵌入式设备的计算能力和内存资源有限,需要特别注意模型的体积和计算效率,以确保在嵌入式系统中的实时性和稳定性。
[环境设置和数据准备]
首先,我们需要准备训练环境,需要确保已经安装了PyTorch和其他必要的库,包括matplotlib和onnxruntime。在安装了Python和PyTorch库后,我们使用 `torchvision` 来下载MNIST数据集,并通过 `DataLoader` 进行加载。这个数据集包含了60,000个训练样本和10,000个测试样本,每个样本都是28x28像素的灰度图像,每张图片都有相应的标签。
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
def get_data_loader(is_train):
# 数据加载函数
to_tensor = transforms.Compose([transforms.ToTensor()])
data_set = MNIST("", is_train, transform=to_tensor, download=True)
return DataLoader(data_set, batch_size=15, shuffle=True)
[模型设计]
构建一个简单的前馈神经网络模型,每个隐藏层包含64个神经元,使用ReLU激活函数。输出层是一个包含10个神经元的softmax层,用于对10个数字进行分类。基于模型的轻量级和计算效率考虑,以便在嵌入式设备上运行。
class Net(torch.nn.Module):
def __init__(self):
super().__init__()
# 神经网络层定义
self.fc1 = torch.nn.Linear(28 * 28, 64)
self.fc2 = torch.nn.Linear(64, 64)
self.fc3 = torch.nn.Linear(64, 64)
self.fc4 = torch.nn.Linear(64, 10)
def forward(self, x):
# 前向传播过程
x = torch.nn.functional.relu(self.fc1(x))
x = torch.nn.functional.relu(self.fc2(x))
x = torch.nn.functional.relu(self.fc3(x))
x = torch.nn.functional.log_softmax(self.fc4(x), dim=1)
return x
[训练和优化]
使用MNIST数据集进行模型的训练和优化。训练过程中采用Adam优化器和负对数似然损失函数。模型在3个周期内进行训练,并在每个周期结束后评估其在测试集上的准确率。通过反复训练和优化,我们使模型在保持准确性的同时尽可能减小了参数量和计算复杂度,以适应嵌入式设备的限制。
def evaluate(test_data, net):
# 模型评估函数
n_correct = 0
n_total = 0
with torch.no_grad():
for (x, y) in test_data:
outputs = net.forward(x.view(-1, 28 * 28))
for i, output in enumerate(outputs):
if torch.argmax(output) == y[i]:
n_correct += 1
n_total += 1
return n_correct / n_total
def main():
train_data = get_data_loader(is_train=True)
test_data = get_data_loader(is_train=False)
net = Net()
print("模型初始准确率:", evaluate(test_data, net))
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
for epoch in range(3):
for (x, y) in train_data:
net.zero_grad()
output = net.forward(x.view(-1, 28 * 28))
loss = torch.nn.functional.nll_loss(output, y)
loss.backward()
optimizer.step()
print("第", epoch+1, "轮准确率:", evaluate(test_data, net))
[测试和结果]
在测试集上评估模型的准确率,确保其具有良好的泛化能力。
我们可以看到,在没有开始训练时,模型初始的准确率为0.1左右,但第一轮后准确率就直接达到了0.9471,第三轮就达到了0.9706,这个准确度相当不错了。
(忽略第2、3行的警告信息,原因是的电脑没有安装显卡)
[模型导出与测试]
将训练好的模型保存为.pth文件,并转换为ONNX格式,以便在RV1106开发板上进行部署。通过这些步骤,我们成功地开发并优化了一个适用于嵌入式设备的轻量级数字识别模型。
def save_model(net, filename):
# 保存模型为.pth文件
torch.save(net.state_dict(), filename)
def convert_to_onnx(net, input_size, filename):
# 转换模型到ONNX格式
model = net.eval() # 设置模型为评估模式
x = torch.randn(input_size, requires_grad=True) # 创建一个随机输入
torch_out = model(x)
# 导出模型
torch.onnx.export(model, # 运行模型
x, # 模型输入 (或一个元组对于多个输入)
filename, # where to save the model (can be a file or file-like object)
export_params=True, # store the trained parameter weights inside the model file
opset_version=10, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names = ['input'], # the model's input names
output_names = ['output'], # the model's output names
dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes
'output' : {0 : 'batch_size'}})
同时我们也可以重新加载导出的模型,测试导出模型是否能够正常工作;
import torch
import torch.onnx
import onnxruntime as ort
import numpy as np
from pytorch_number import Net, get_data_loader, evaluate
def load_model(filename):
# 加载.pth文件中的模型
net = Net()
net.load_state_dict(torch.load(filename))
net.eval()
return net
def test_model(model, test_data):
# 模型测试函数
accuracy = evaluate(test_data, model)
print(f".pth文件模型在测试集上的准确率: {accuracy * 100:.2f}%")
def load_onnx_model(filename):
# 使用onnxruntime加载onnx模型
ort_session = ort.InferenceSession(filename)
return ort_session
def test_onnx_model(ort_session, test_data):
n_correct = 0
n_total = 0
for (x, y) in test_data:
x = x.view(-1, 28 * 28).numpy().astype(np.float32)
ort_inputs = {ort_session.get_inputs()[0].name: x}
ort_outs = ort_session.run(None, ort_inputs)
predicted = np.argmax(ort_outs[0], axis=1)
n_correct += np.sum(predicted == y.numpy())
n_total += y.size(0)
accuracy = n_correct / n_total
print(f".onnx文件模型在测试集上的准确率: {accuracy * 100:.2f}%")
if __name__ == "__main__":
# 获取测试数据
test_data = get_data_loader(is_train=False)
# 加载.pth文件中的模型并测试
loaded_model = load_model("mnist_model.pth")
test_model(loaded_model, test_data)
# 加载.onnx文件中的模型并测试
onnx_model = load_onnx_model("mnist_model.onnx")
test_onnx_model(onnx_model, test_data)
输出内容如下所示,与我们在训练完成后的模型一致。
本帖最后由 luyism 于 2024-4-17 16:37 编辑