使用EMNIST数据集训练第一个pytorch CNN手写字母识别神经网络

MNIST 这里就不多展开了,我们上几期的文章都是使用此数据集进行的分享 。

使用EMNIST数据集训练第一个pytorch CNN手写字母识别神经网络

文章插图
手写字母识别
EMNIST数据集Extended MNIST (EMNIST), 因为 MNIST 被大家熟知,所以这里就推出了 EMNIST,一个在手写字体分类任务中更有挑战的 Benchmark。此数据集当然也包含了手写数字的数据集
在数据集接口上,此数据集的处理方式跟 MNIST 保持一致,也是为了方便已经熟悉 MNIST 的我们去使用,这里着重介绍一下 EMNIST 的分类方式 。
分类方式
EMNIST 主要分为以下 6 类:
By_Class : 共 814255 张,62 类,与 NIST 相比重新划分类训练集与测试机的图片数
【使用EMNIST数据集训练第一个pytorch CNN手写字母识别神经网络】By_Merge: 共 814255 张,47 类,与 NIST 相比重新划分类训练集与测试机的图片数
Balanced : 共 131600 张,47 类,每一类都包含了相同的数据,每一类训练集 2400 张,测试集 400 张
Digits :共 28000 张,10 类,每一类包含相同数量数据,每一类训练集 24000 张,测试集 4000 张
Letters : 共 103600 张,37 类,每一类包含相同数据,每一类训练集 2400 张,测试集 400 张
MNIST : 共 70000 张,10 类,每一类包含相同数量数据(注:这里虽然数目和分类都一样,但是图片的处理方式不一样,EMNIST 的 MNIST 子集数字占的比重更大)
这里为什么后面的分类不是26+26?其主要原因是一些大小写字母比较类似的字母就合并了,比如C等等
使用EMNIST数据集训练第一个pytorch CNN手写字母识别神经网络

文章插图
 
代码实现手写字母训练神经网络由于EMNIST数据集与MNIST类似,我们直接使用MNIST的训练代码进行此神经网络的训练
import torchimport torch.nn as nnimport torch.utils.data as Dataimport torchvision# 数据库模块import matplotlib.pyplot as plt# torch.manual_seed(1)# reproducibleEPOCH = 1# 训练整批数据次数,训练次数越多,精度越高,为了演示,我们训练5次BATCH_SIZE = 50# 每次训练的数据集个数LR = 0.001# 学习效率DOWNLOAD_MNIST = Ture# 如果你已经下载好了EMNIST数据就设置 False# EMNIST 手写字母 训练集train_data = https://www.isolves.com/it/ai/2021-09-08/torchvision.datasets.EMNIST(root='./data',train=True,transform=torchvision.transforms.ToTensor(),download = DOWNLOAD_MNIST,split = 'letters' )# EMNIST 手写字母 测试集test_data = torchvision.datasets.EMNIST(root='./data',train=False,transform=torchvision.transforms.ToTensor(),download=False,split = 'letters')# 批训练 50samples, 1 channel, 28x28 (50, 1, 28, 28)train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)# 每一步 loader 释放50个数据用来学习# 为了演示, 我们测试时提取2000个数据先# shape from (2000, 28, 28) to (2000, 1, 28, 28), value in range(0,1)test_x = torch.unsqueeze(test_data.data, dim=1).type(torch.FloatTensor)[:2000] / 255.test_y = test_data.targets[:2000]#test_x = test_x.cuda() # 若有cuda环境,取消注释#test_y = test_y.cuda() # 若有cuda环境,取消注释首先,我们下载EMNIST数据集,这里由于我们分享过手写数字的部分,这里我们按照手写字母的部分进行神经网络的训练,其split为letters,这里跟MNIST数据集不一样的地方便是多了一个split标签,备注我们需要那个分类的数据
torchvision.datasets.EMNIST(root: str, split: str, **kwargs: Any)root ( string ) –数据集所在EMNIST/processed/training.pt 和 EMNIST/processed/test.pt存在的根目录 。split(字符串)-该数据集具有6个不同的拆分:byclass,bymerge,balanced,letters,digits和mnist 。此参数指定使用哪一个 。train ( bool , optional )– 如果为 True,则从 中创建数据集training.pt,否则从test.pt.download ( bool , optional ) – 如果为 true,则从 Internet 下载数据集并将其放在根目录中 。如果数据集已经下载,则不会再次下载 。transform ( callable , optional ) – 一个函数/转换,它接收一个 PIL 图像并返回一个转换后的版本 。例如,transforms.RandomCroptarget_transform ( callable , optional ) – 一个接收目标并对其进行转换的函数/转换 。可视化数据集然后我们可视化一下此数据集,看看此数据集什么样子
def get_mApping(num, with_type='letters'):"""根据 mapping,由传入的 num 计算 UTF8 字符"""if with_type == 'byclass':if num <= 9:return chr(num + 48)# 数字elif num <= 35:return chr(num + 55)# 大写字母else:return chr(num + 61)# 小写字母elif with_type == 'letters':return chr(num + 64) + " / " + chr(num + 96)# 大写/小写字母elif with_type == 'digits':return chr(num + 96)else:return numfigure = plt.figure(figsize=(8, 8))cols, rows = 3, 3for i in range(1, cols * rows + 1):sample_idx = torch.randint(len(train_data), size=(1,)).item()img, label = train_data[sample_idx]print(label)figure.add_subplot(rows, cols, i)plt.title(get_mapping(label))plt.axis("off")plt.imshow(img.squeeze(), cmap="gray")plt.show()


推荐阅读