MNIST 这里就不多展开了,我们上几期的文章都是使用此数据集进行的分享 。
文章插图
手写字母识别
EMNIST数据集Extended MNIST (EMNIST), 因为 MNIST 被大家熟知,所以这里就推出了 EMNIST,一个在手写字体分类任务中更有挑战的 Benchmark。此数据集当然也包含了手写数字的数据集
在数据集接口上,此数据集的处理方式跟 MNIST 保持一致,也是为了方便已经熟悉 MNIST 的我们去使用,这里着重介绍一下 EMNIST 的分类方式 。
分类方式
EMNIST 主要分为以下 6 类:
By_Class : 共 814255 张,62 类,与 NIST 相比重新划分类训练集与测试机的图片数
【使用EMNIST数据集训练第一个pytorch CNN手写字母识别神经网络】By_Merge: 共 814255 张,47 类,与 NIST 相比重新划分类训练集与测试机的图片数
Balanced : 共 131600 张,47 类,每一类都包含了相同的数据,每一类训练集 2400 张,测试集 400 张
Digits :共 28000 张,10 类,每一类包含相同数量数据,每一类训练集 24000 张,测试集 4000 张
Letters : 共 103600 张,37 类,每一类包含相同数据,每一类训练集 2400 张,测试集 400 张
MNIST : 共 70000 张,10 类,每一类包含相同数量数据(注:这里虽然数目和分类都一样,但是图片的处理方式不一样,EMNIST 的 MNIST 子集数字占的比重更大)
这里为什么后面的分类不是26+26?其主要原因是一些大小写字母比较类似的字母就合并了,比如C等等
文章插图
代码实现手写字母训练神经网络由于EMNIST数据集与MNIST类似,我们直接使用MNIST的训练代码进行此神经网络的训练
import torchimport torch.nn as nnimport torch.utils.data as Dataimport torchvision# 数据库模块import matplotlib.pyplot as plt# torch.manual_seed(1)# reproducibleEPOCH = 1# 训练整批数据次数,训练次数越多,精度越高,为了演示,我们训练5次BATCH_SIZE = 50# 每次训练的数据集个数LR = 0.001# 学习效率DOWNLOAD_MNIST = Ture# 如果你已经下载好了EMNIST数据就设置 False# EMNIST 手写字母 训练集train_data = https://www.isolves.com/it/ai/2021-09-08/torchvision.datasets.EMNIST(root='./data',train=True,transform=torchvision.transforms.ToTensor(),download = DOWNLOAD_MNIST,split = 'letters' )# EMNIST 手写字母 测试集test_data = torchvision.datasets.EMNIST(root='./data',train=False,transform=torchvision.transforms.ToTensor(),download=False,split = 'letters')# 批训练 50samples, 1 channel, 28x28 (50, 1, 28, 28)train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)# 每一步 loader 释放50个数据用来学习# 为了演示, 我们测试时提取2000个数据先# shape from (2000, 28, 28) to (2000, 1, 28, 28), value in range(0,1)test_x = torch.unsqueeze(test_data.data, dim=1).type(torch.FloatTensor)[:2000] / 255.test_y = test_data.targets[:2000]#test_x = test_x.cuda() # 若有cuda环境,取消注释#test_y = test_y.cuda() # 若有cuda环境,取消注释
首先,我们下载EMNIST数据集,这里由于我们分享过手写数字的部分,这里我们按照手写字母的部分进行神经网络的训练,其split为letters,这里跟MNIST数据集不一样的地方便是多了一个split标签,备注我们需要那个分类的数据torchvision.datasets.EMNIST(root: str, split: str, **kwargs: Any)root ( string ) –数据集所在EMNIST/processed/training.pt 和 EMNIST/processed/test.pt存在的根目录 。split(字符串)-该数据集具有6个不同的拆分:byclass,bymerge,balanced,letters,digits和mnist 。此参数指定使用哪一个 。train ( bool , optional )– 如果为 True,则从 中创建数据集training.pt,否则从test.pt.download ( bool , optional ) – 如果为 true,则从 Internet 下载数据集并将其放在根目录中 。如果数据集已经下载,则不会再次下载 。transform ( callable , optional ) – 一个函数/转换,它接收一个 PIL 图像并返回一个转换后的版本 。例如,transforms.RandomCroptarget_transform ( callable , optional ) – 一个接收目标并对其进行转换的函数/转换 。
可视化数据集然后我们可视化一下此数据集,看看此数据集什么样子def get_mApping(num, with_type='letters'):"""根据 mapping,由传入的 num 计算 UTF8 字符"""if with_type == 'byclass':if num <= 9:return chr(num + 48)# 数字elif num <= 35:return chr(num + 55)# 大写字母else:return chr(num + 61)# 小写字母elif with_type == 'letters':return chr(num + 64) + " / " + chr(num + 96)# 大写/小写字母elif with_type == 'digits':return chr(num + 96)else:return numfigure = plt.figure(figsize=(8, 8))cols, rows = 3, 3for i in range(1, cols * rows + 1):sample_idx = torch.randint(len(train_data), size=(1,)).item()img, label = train_data[sample_idx]print(label)figure.add_subplot(rows, cols, i)plt.title(get_mapping(label))plt.axis("off")plt.imshow(img.squeeze(), cmap="gray")plt.show()
推荐阅读
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- Android开发中关于使用权限的常见错误
- go-micro的安装和使用
- 苹果|苹果SIM卡针卖26元网友反向点赞:还有949元数据线、2599元钥匙扣、3299元行李牌
- 迷迭香不适合人群,迷迭香精油使用禁忌
- 男士脱毛膏使用注意事项有哪些
- axure rp9的属性在哪?axure rp9怎么使用
- PS经典调色教程几种调色工具结合使用技巧
- 办公室电脑怎样限制U盘的使用?
- 盘点市面上主流的时序数据库
- Mac上的这些数据库管理软件,值得推荐?