用Java训练深度学习模型,原来可以这么简单!
- 1、下载文档前请自行甄别文档内容的完整性,平台不提供额外的编辑、内容补充、找答案等附加服务。
- 2、"仅部分预览"的文档,不可在线预览部分如存在完整性等问题,可反馈申请退款(可完整预览的文档不适用该条件!)。
- 3、如文档侵犯您的权益,请联系客服反馈,我们会尽快为您处理(人工客服工作时间:9:00-18:30)。
⽤Java训练深度学习模型,原来可以这么简单!
本⽂适合有 Java 基础的⼈群
作者:DJL-Keerthan&Lanking
HelloGitHub 推出的系列。
这⼀期是由亚马逊⼯程师:,为我们讲解 DJL(完全由 Java 构建的深度学习平台)系列的第 4 篇。
⼀、前⾔
很长时间以来,Java 都是⼀个很受企业欢迎的编程语⾔。
得益于丰富的⽣态以及完善维护的包和框架,Java 拥有着庞⼤的开发者社区。
尽管深度学习应⽤的不断演进和落地,提供给 Java 开发者的框架和库却⼗分短缺。
现今主要流⾏的深度学习模型都是⽤ Python 编译和训练的。
对于 Java 开发者⽽⾔,如果要进军深度学习界,就需要重新学习并接受⼀门新的编程语⾔同时还要学习深度学习的复杂知识。
这使得⼤部分 Java 开发者学习和转型深度学习开发变得困难重重。
为了减少 Java 开发者学习深度学习的成本,AWS 构建了 Deep Java Library (DJL),⼀个为 Java 开发者定制的开源深度学习框架。
它为Java 开发者对接主流深度学习框架提供了⼀个桥梁。
在这篇⽂章中,我们会尝试⽤ DJL 构建⼀个深度学习模型并⽤它训练 MNIST ⼿写数字识别任务。
⼆、什么是深度学习?
在我们正式开始之前,我们先来了解⼀下机器学习和深度学习的基本概念。
机器学习是⼀个通过利⽤统计学知识,将数据输⼊到计算机中进⾏训练并完成特定⽬标任务的过程。
这种归纳学习的⽅法可以让计算机学习⼀些特征并进⾏⼀系列复杂的任务,⽐如识别照⽚中的物体。
由于需要写复杂的逻辑以及测量标准,这些任务在传统计算科学领域中很难实现。
深度学习是机器学习的⼀个分⽀,主要侧重于对于⼈⼯神经⽹络的开发。
⼈⼯神经⽹络是通过研究⼈脑如何学习和实现⽬标的过程中归纳⽽得出⼀套计算逻辑。
它通过模拟部分⼈脑神经间信息传递的过程,从⽽实现各类复杂的任务。
深度学习中的“深度”来源于我们会在⼈⼯神经⽹络中编织构建出许多层(layer)从⽽进⼀步对数据信息进⾏更深层的传导。
深度学习技术应⽤范围⼗分⼴泛,现在被⽤来做⽬标检测、动作识别、机器翻译、语意分析等各类现实应⽤中。
三、训练 MNIST ⼿写数字识别
3.1 项⽬配置
你可以⽤如下的gradle配置来引⼊依赖项。
在这个案例中,我们⽤ DJL 的 api 包 (核⼼ DJL 组件) 和 basicdataset 包 (DJL 数据集) 来构建神经⽹络和数据集。
这个案例中我们使⽤了 MXNet 作为深度学习引擎,所以我们会引⼊mxnet-engine和mxnet-native-auto两个包。
这个案例也可以运⾏在 PyTorch 引擎下,只需要替换成对应的软件包即可。
plugins {
id 'java'
}
repositories {
jcenter()
}
dependencies {
implementation platform("ai.djl:bom:0.8.0")
implementation "ai.djl:api"
implementation "ai.djl:basicdataset"
// MXNet
runtimeOnly "ai.djl.mxnet:mxnet-engine"
runtimeOnly "ai.djl.mxnet:mxnet-native-auto"
}
3.2 NDArray 和 NDManager
NDArray 是 DJL 存储数据结构和数学运算的基本结构。
⼀个 NDArray 表达了⼀个定长的多维数组。
NDArray 的使⽤⽅法类似于 Python 中的numpy.ndarray。
NDManager 是 NDArray 的⽼板。
它负责管理 NDArray 的产⽣和回收过程,这样可以帮助我们更好的对 Java 内存进⾏优化。
每⼀个NDArray 都会是由⼀个 NDManager 创造出来,同时它们会在 NDManager 关闭时⼀同关闭。
NDManager 和 NDArray 都是由 Java 的AutoClosable 构建,这样可以确保在运⾏结束时及时进⾏回收。
想了解更多关于它们的⽤法和实践,请参阅我们前⼀期⽂章:
Model
在 DJL 中,训练和推理都是从 Model class 开始构建的。
我们在这⾥主要讲训练过程中的构建⽅法。
下⾯我们为 Model 创建⼀个新的⽬标。
因为 Model 也是继承了 AutoClosable 结构体,我们会⽤⼀个 try block 实现:
try (Model model = Model.newInstance()) {
...
// 主体训练代码
...
}
准备数据
MNIST(Modified National Institute of Standards and Technology)数据库包含⼤量⼿写数字的图,通常被⽤来训练图像处理系统。
DJL 已经将 MNIST 的数据集收录到了 basicdataset 数据集⾥,每个 MNIST 的图的⼤⼩是28 x 28。
如果你有⾃⼰的数据集,你也可以通过 DJL 数据集导⼊教程来导⼊数据集到你的训练任务中。
int batchSize = 32; // 批⼤⼩
Mnist trainingDataset = Mnist.builder()
.optUsage(Usage.TRAIN) // 训练集
.setSampling(batchSize, true)
.build();
Mnist validationDataset = Mnist.builder()
.optUsage(Usage.TEST) // 验证集
.setSampling(batchSize, true)
.build();
这段代码分别制作出了训练和验证集。
同时我们也随机排列了数据集从⽽更好的训练。
除了这些配置以外,你也可以添加对于图⽚的进⼀步处理,⽐如设置图⽚⼤⼩,对图⽚进⾏归⼀化等处理。
制作 model(建⽴ Block)
当你的数据集准备就绪后,我们就可以构建神经⽹络了。
在 DJL 中,神经⽹络是由 Block(代码块)构成的。
⼀个 Block 是⼀个具备多种神经⽹络特性的结构。
它们可以代表⼀个操作, 神经⽹络的⼀部分,甚⾄是⼀个完整的神经⽹络。
然后 Block 可以顺序执⾏或者并⾏。
同时Block 本⾝也可以带参数和⼦ Block。
这种嵌套结构可以帮助我们构造⼀个复杂但⼜不失维护性的神经⽹络。
在训练过程中,每个 Block 中附带的参数会被实时更新,同时也包括它们的各个⼦ Block。
这种递归更新的过程可以确保整个神经⽹络得到充分训练。
当我们构建这些 Block 的过程中,最简单的⽅式就是将它们⼀个⼀个的嵌套起来。
直接使⽤准备好 DJL 的 Block 种类,我们就可以快速制作出各类神经⽹络。
根据⼏种基本的神经⽹络⼯作模式,我们提供了⼏种 Block 的变体。
SequentialBlock 是为了应对顺序执⾏每⼀个⼦ Block 构造⽽成的。
它会将前⼀个⼦ Block 的输出作为下⼀个 Block 的输⼊继续执⾏到底。
与之对应的,是 ParallelBlock 它⽤于将⼀个输⼊并⾏输⼊到每⼀个⼦Block 中,同时将输出结果根据特定的合并⽅程合并起来。
最后我们说⼀下 LambdaBlock,它是帮助⽤户进⾏快速操作的⼀个 Block,其中并不具备任何参数,所以也没有任何部分在训练过程中更新。
我们来尝试创建⼀个基本的多层感知机(MLP)神经⽹络吧。
多层感知机是⼀个简单的前向型神经⽹络,它只包含了⼏个全连接层(LinearBlock)。
那么构建这个⽹络,我们可以直接使⽤ SequentialBlock。
int input = 28 * 28; // 输⼊层⼤⼩
int output = 10; // 输出层⼤⼩
int[] hidden = new int[] {128, 64}; // 隐藏层⼤⼩
SequentialBlock sequentialBlock = new SequentialBlock();
sequentialBlock.add(Blocks.batchFlattenBlock(input));
for (int hiddenSize : hidden) {
// 全连接层
sequentialBlock.add(Linear.builder().setUnits(hiddenSize).build());
// 激活函数
sequentialBlock.add(activation);
}
sequentialBlock.add(Linear.builder().setUnits(output).build());
当然 DJL 也提供了直接就可以拿来⽤的 MLP Block :
Block block = new Mlp(
Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH,
Mnist.NUM_CLASSES,
new int[] {128, 64});
训练
当我们准备好数据集和神经⽹络之后,就可以开始训练模型了。
在深度学习中,⼀般会由下⾯⼏步来完成⼀个训练过程:
初始化:我们会对每⼀个 Block 的参数进⾏初始化,初始化每个参数的函数都是由设定的 Initializer 决定的。
前向传播:这⼀步将输⼊数据在神经⽹络中逐层传递,然后产⽣输出数据。
计算损失:我们会根据特定的损失函数 Loss 来计算输出和标记结果的偏差。
反向传播:在这⼀步中,你可以利⽤损失反向求导算出每⼀个参数的梯度。
更新权重:我们会根据选择的优化器(Optimizer)更新每⼀个在 Block 上参数的值。
DJL 利⽤了 Trainer 结构体精简了整个过程。
开发者只需要创建 Trainer 并指定对应的 Initializer、Loss 和 Optimizer 即可。
这些参数都是由TrainingConfig 设定的。
下⾯我们来看⼀下具体的参数设置:
TrainingListener:这个是对训练过程设定的监听器。
它可以实时反馈每个阶段的训练结果。
这些结果可以⽤于记录训练过程或者帮助debug 神经⽹络训练过程中的问题。
⽤户也可以定制⾃⼰的 TrainingListener 来对训练过程进⾏监听。
DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
.addEvaluator(new Accuracy())
.addTrainingListeners(TrainingListener.Defaults.logging());
try (Trainer trainer = model.newTrainer(config)){
// 训练代码
}
当训练器产⽣后,我们可以定义输⼊的 Shape。
之后就可以调⽤ fit 函数来进⾏训练。
fit 函数会对输⼊数据,训练多个 epoch 是并最终将结果存储在本地⽬录下。
/*
* MNIST 包含 28x28 灰度图⽚并导⼊成 28 * 28 NDArray。
* 第⼀个维度是批⼤⼩, 在这⾥我们设置批⼤⼩为 1 ⽤于初始化。
*/
Shape inputShape = new Shape(1, Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH);
int numEpoch = 5;
String outputDir = "/build/model";
// ⽤输⼊初始化 trainer
trainer.initialize(inputShape);
TrainingUtils.fit(trainer, numEpoch, trainingSet, validateSet, outputDir, "mlp");
这就是训练过程的全部流程了!⽤ DJL 训练是不是还是很轻松的?之后看⼀下输出每⼀步的训练结果。
如果你⽤了我们默认的监听器,那么输出是类似于下图:
[INFO ] - Downloading libmxnet.dylib ...
[INFO ] - Training on: cpu().
[INFO ] - Load MXNet Engine Version 1.7.0 in 0.131 ms.
Training: 100% |████████████████████████████████████████| Accuracy: 0.93, SoftmaxCrossEntropyLoss: 0.24, speed: 1235.20 items/sec Validating: 100% |████████████████████████████████████████|
[INFO ] - Epoch 1 finished.
[INFO ] - Train: Accuracy: 0.93, SoftmaxCrossEntropyLoss: 0.24
[INFO ] - Validate: Accuracy: 0.95, SoftmaxCrossEntropyLoss: 0.14
Training: 100% |████████████████████████████████████████| Accuracy: 0.97, SoftmaxCrossEntropyLoss: 0.10, speed: 2851.06 items/sec Validating: 100% |████████████████████████████████████████|
[INFO ] - Epoch 2 finished.NG [1m 41s]
[INFO ] - Train: Accuracy: 0.97, SoftmaxCrossEntropyLoss: 0.10
[INFO ] - Validate: Accuracy: 0.97, SoftmaxCrossEntropyLoss: 0.09
[INFO ] - train P50: 12.756 ms, P90: 21.044 ms
[INFO ] - forward P50: 0.375 ms, P90: 0.607 ms
[INFO ] - training-metrics P50: 0.021 ms, P90: 0.034 ms
[INFO ] - backward P50: 0.608 ms, P90: 0.973 ms
[INFO ] - step P50: 0.543 ms, P90: 0.869 ms
[INFO ] - epoch P50: 35.989 s, P90: 35.989 s
当训练结果完成后,我们可以⽤刚才的模型进⾏推理来识别⼿写数字。
如果刚才的内容哪⾥有不是很清楚的,可以参照下⾯两个链接直接尝试训练。
四、最后
在这个⽂章中,我们介绍了深度学习的基本概念,同时还有如何优雅的利⽤ DJL 构建深度学习模型并进⾏训练。
DJL 也提供了更加多样的数据集和神经⽹络。
如果有兴趣学习深度学习,可以参阅我们的 Java 深度学习书。
Deep Java Library(DJL)是⼀个基于 Java 的深度学习框架,同时⽀持训练以及推理。
DJL 博取众长,构建在多个深度学习框架之上(TenserFlow、PyTorch、MXNet 等) 也同时具备多个框架的优良特性。
你可以轻松使⽤ DJL 来进⾏训练然后部署你的模型。
它同时拥有着强⼤的模型库⽀持:只需⼀⾏便可以轻松读取各种预训练的模型。
现在 DJL 的模型库同时⽀持⾼达 70 多个来⾃ GluonCV、HuggingFace、TorchHub 以及 Keras 的模型。