竹笋

首页 » 问答 » 环境 » 一种深度学习框架的介绍及总结
TUhjnbcbe - 2023/4/22 19:32:00

目前TensorFlow、Keras等深度学习框架被广泛使用,所以,我今天要介绍另一个框架Deeplearning4J。

它是一个基于纯Java的深度学习框架,同时也是Java的第一大深度学习框架,目前在Github上已经有9k多颗星了。它运行在纯Java环境,能够很好的应用在生产环境中,且能够较好的兼容TensorFlow、Keras等主流框架训练的模型。

下面我将从几个方面介绍DL4J。

1.Deeplearning4J可以解决哪些问题?

Deeplearning4J基于Nd4j(可以看成是java版的numpy),即底层由JavaCPP实现,通过把正向反向传播的过程向量化,调用基于CPU或者GPU的并行计算库(OpenBlas、MKL、CuBlas等现行代数运算库),实现矩阵操作的并行计算。

Deeplearning4J提供两种方式的网络构建方法,一种是MultiLayerNetwork,层与层之间是一层层具有先后顺序地堆叠起来的;另一种是ComputationGraph,即手工构建计算图的网络;使用MultiLayerNetwork方式构建网络更加方便;使用ComputationGraph的方式构建的网络更加的灵活。两种方式分别对应Keras里面的Sequential和Model。

其实深度神经网络从本质上来讲和浅层神经网络并无区别,Deeplearning4J实现了主流神经网络中的神经网络单元、结构、优化方法、以及一些Trick。你可以构建逻辑回归、深度神经网络、CNN网络、RNN网络,做单分类、序列分类、多类标分类、多标签分类,也可以做回归;可以构建简单的Seq2seq模型,也可以构建VAE(变分编码器)这种复杂的encoder-decoder;可以实现MultiInput的网络,也可以实现MultiChannel的网络,还可以做MultiTaskLearning;可以做fine-tuning,也可以transferlearning。总之能够较好的应用在任何需要实现AI功能的应用场景中。

更多的Deeplearning4J的介绍和使用就直接Google吧。

2..Deeplearning4J在NLP中应用。

Deeplearning4J用于NLP领域——使用SkipGram+NegativeSampling进行wordembedding模型的训练和使用,用于文本快速相似度计算;基于Static、Non-static、以及Hybrid的embedding作为输入构建CNN+RNN的深度神经网络模型,用于文本分类任务;Deeplearning4J的整体架构还是设计的不错,且包含大部分Paper中的Trick,都在我们的模型训练中有应用,如Dropout、BatchNormalization等,并产生了良好的效果,有效防止了过拟合的产生。模型训练时,我们使用了Hyperparameter的自动调节,通过交叉验证进行自动结果评估,自动选取和优化模型超参数,让算法开发人员能够花更多的心思在数据的处理上(个人认为,超参数的设定对模型结果的影响不是主要的)。Deeplearning4J提供了一个类似于TensorBoard的module,能够让算法开发人员实施监控到损失的score(即loss),同时也使用了EarlyStop策略,及时阻止过拟合的发生。

3.若已有TensorFlow或者Keras训练好的模型了,怎么在现有系统中使用?怎么和Deeplearning4J结合?

Deeplearning4J提供model-import模块。可以导入Keras保存的模型文件。而Keras又以TensorFlow、Theano、Caffe、Torch四大框架为Backend。于是就有了下面这个架构:

4.如果想分布式训练,怎么部署?

Deeplearning4J作为Java中的第一大深度学习库,一大好处就是可以对接Spark,做分布式计算。即使通过CPU,也能取得FLOP。

5.优缺点介绍

优点:

(1)性能优异。基于并行库加速,在CPU上做预测,也能快速响应;

(2)Deeplearning4J的发展一直处于上升期。0.7一直更新到1.0.0-beta1,均有良好表现;

(3)能够复用Keras、TF、Torch、Theano训练的模型;

(4)和Spark无缝衔接。基于Java的附加红利就是支持Scala,于是可以和Spark衔接,做分布式训练。

缺点:

(1)比起TF、PyTorch这些网红,Deeplearning4J还是显得小众,文档支持较少,生态较为单薄,且功能没那么强大。

(2)缺少一些最前沿的方法的支持,比如需要自己实现自定义的Layer来实现一些牛逼特性和算法,如Attention、NCRF等;

(3)BiDirectionalLSTM实现有bug1.0.0-beta2版本优化方法有bug

1
查看完整版本: 一种深度学习框架的介绍及总结