用 Java 训练深度学习模型,原来能这么简单

作者:DJL-Keerthan&Lanking
【用 Java 训练深度学习模型,原来能这么简单】 
HelloGitHub 推出的《讲解开源项目》 系列 。这一期是由亚马逊工程师:Keerthan Vasist,为我们讲解 DJL(完全由 JAVA 构建的深度学习平台)系列的第 4 篇 。
一、前言很长时间以来,Java 都是一个很受企业欢迎的编程语言 。得益于丰富的生态以及完善维护的包和框架,Java 拥有着庞大的开发者社区 。尽管深度学习应用的不断演进和落地,提供给 Java 开发者的框架和库却十分短缺 。现今主要流行的深度学习模型都是用 Python 编译和训练的 。对于 Java 开发者而言,如果要进军深度学习界,就需要重新学习并接受一门新的编程语言同时还要学习深度学习的复杂知识 。这使得大部分 Java 开发者学习和转型深度学习开发变得困难重重 。
 
为了减少 Java 开发者学习深度学习的成本,AWS 构建了 Deep Java Library (DJL),一个为 Java 开发者定制的开源深度学习框架 。它为 Java 开发者对接主流深度学习框架提供了一个桥梁 。

用 Java 训练深度学习模型,原来能这么简单

文章插图
 
在这篇文章中,我们会尝试用 DJL 构建一个深度学习模型并用它训练 MNIST 手写数字识别任务 。
二、什么是深度学习?在我们正式开始之前,我们先来了解一下机器学习和深度学习的基本概念 。
 
机器学习是一个通过利用统计学知识,将数据输入到计算机中进行训练并完成特定目标任务的过程 。这种归纳学习的方法可以让计算机学习一些特征并进行一系列复杂的任务,比如识别照片中的物体 。由于需要写复杂的逻辑以及测量标准,这些任务在传统计算科学领域中很难实现 。
 
深度学习是机器学习的一个分支,主要侧重于对于人工神经网络的开发 。人工神经网络是通过研究人脑如何学习和实现目标的过程中归纳而得出一套计算逻辑 。它通过模拟部分人脑神经间信息传递的过程,从而实现各类复杂的任务 。深度学习中的“深度”来源于我们会在人工神经网络中编织构建出许多层(layer)从而进一步对数据信息进行更深层的传导 。深度学习技术应用范围十分广泛,现在被用来做目标检测、动作识别、机器翻译、语意分析等各类现实应用中 。
三、训练 MNIST 手写数字识别3.1 项目配置你可以用如下的 gradle 配置来引入依赖项 。在这个案例中,我们用 DJL 的 api 包 (核心 DJL 组件) 和 basicdataset 包 (DJL 数据集) 来构建神经网络和数据集 。这个案例中我们使用了 MXNet 作为深度学习引擎,所以我们会引入 mxnet-engine 和 mxnet-native-auto 两个包 。这个案例也可以运行在 PyTorch 引擎下,只需要替换成对应的软件包即可 。
plugins {    id 'java'}repositories {                               jcenter()}dependencies {    implementation platform("ai.djl:bom:0.8.0")    implementation "ai.djl:api"    implementation "ai.djl:basicdataset"    // MXNet    runtimeOnly "ai.djl.mxnet:mxnet-engine"    runtimeOnly "ai.djl.mxnet:mxnet-native-auto"}3.2 NDArray 和 NDManagerNDArray 是 DJL 存储数据结构和数学运算的基本结构 。一个 NDArray 表达了一个定长的多维数组 。NDArray 的使用方法类似于 Python 中的 numpy.ndarray 。
 
NDManager 是 NDArray 的老板 。它负责管理 NDArray 的产生和回收过程,这样可以帮助我们更好的对 Java 内存进行优化 。每一个 NDArray 都会是由一个 NDManager 创造出来,同时它们会在 NDManager 关闭时一同关闭 。
 
NDManager 和 NDArray 都是由 Java 的 AutoClosable 构建,这样可以确保在运行结束时及时进行回收 。想了解更多关于它们的用法和实践,请参阅我们前一期文章:DJL 之 Java 玩转多维数组,就像 NumPy 一样
 
Model在 DJL 中,训练和推理都是从 Model class 开始构建的 。我们在这里主要讲训练过程中的构建方法 。下面我们为 Model 创建一个新的目标 。因为 Model 也是继承了 AutoClosable 结构体,我们会用一个 try block 实现:
try (Model model = Model.newInstance()) {    ...    // 主体训练代码    ...}


推荐阅读