如何快速开始在Java中使用DJL(Deep Java Library)进行深度学习
作为个人站长或技术爱好者,如果你希望在你的公有云虚拟机或VPS上利用Java生态系统进行深度学习模型的部署或开发,Deep Java Library (DJL) 是一个出色的选择。DJL提供了一套统一的API,让你无需关心底层引擎(如PyTorch, TensorFlow, 或 MXNet),即可在JVM上运行各种模型。
本教程将引导你完成DJL的快速启动,并展示如何使用预训练模型进行简单的图像分类推理。
准备工作
- Java环境: 确保你的系统安装了JDK 8或更高版本(推荐JDK 11+)。
- 构建工具: 使用 Maven 或 Gradle。
第一步:配置项目依赖 (Maven)
我们将使用 Maven 来管理项目依赖。DJL 提供了 Bill of Materials (BOM) 来简化版本管理。同时,我们需要引入核心 API 和一个实际的深度学习引擎(这里我们选择 PyTorch)。
在你的 pom.xml 文件中添加以下配置:
<properties>
<!-- 推荐使用最新稳定版本 -->
<djl.version>0.27.0</djl.version>
</properties>
<dependencies>
<!-- 引入DJL BOM以管理版本 -->
<dependency>
<groupId>ai.djl</groupId>
<artifactId>bom</artifactId>
<version>${djl.version}</version>
<type>pom</type>
<scope>import</scope>
</dependency>
<!-- DJL核心API -->
<dependency>
<groupId>ai.djl</groupId>
<artifactId>api</artifactId>
</dependency>
<!-- 引入 PyTorch 引擎和相关的 native C++ 库(CPU版本) -->
<!-- 如果你不需要特定模型库,只使用api和引擎即可 -->
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-engine</artifactId>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-native-cpu</artifactId>
<classifier>linux-x86_64</classifier>
</dependency>
<!-- 引入常用模型的Model Zoo,用于加载预训练模型 -->
<dependency>
<groupId>ai.djl</groupId>
<artifactId>model-zoo</artifactId>
</dependency>
</dependencies>
注意: pytorch-native-cpu 的
需要根据你VPS/VM的操作系统和架构进行调整(例如,Windows 使用 win-x86_64,Mac 使用 osx-x86_64)。对于常见的Linux VPS,linux-x86_64 是正确的选择。
第二步:编写图像分类代码
我们将编写一个简单的 Java 类,使用 DJL 的 Criteria 加载一个预训练的 ResNet50 模型,并对本地图像执行预测。
创建一个名为 QuickStart.java 的文件:
import ai.djl.Application;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.ModelException;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
public class QuickStart {
public static void main(String[] args) throws IOException, ModelException, TranslateException {
// 1. 设置模型加载标准
Criteria<Image, Classifications> criteria = Criteria.builder()
.optApplication(Application.CV.IMAGE_CLASSIFICATION)
// 声明输入和输出类型
.setTypes(Image.class, Classifications.class)
// 指定使用 ResNet50 模型
.optFilter("artifact_id", "resnet")
.optFilter("backbone", "resnet50")
// 确保使用 PyTorch 引擎
.optEngine(Engine.get =DefaultEngineName())
.build();
// 2. 准备输入图像
// 确保你的项目根目录下有一张名为 "test_image.jpg" 的图片。
// 在运行前,你可以下载一张图片并重命名为 test_image.jpg
Path inputPath = Paths.get("test_image.jpg");
if (!inputPath.toFile().exists()) {
System.err.println("错误: 未找到 test_image.jpg 文件。请在当前目录下放置一张图片。");
return;
}
Image image = ImageFactory.getInstance().fromFile(inputPath);
// 3. 加载模型并执行预测
try (ZooModel<Image, Classifications> model = criteria.loadModel();
Predictor<Image, Classifications> predictor = model.newPredictor()) {
System.out.println("模型加载完成,开始预测...");
// 首次运行时,模型会从网络下载,请耐心等待。
Classifications classifications = predictor.predict(image);
// 4. 输出结果
System.out.println("--- 图像预测结果 (Top 5) ---");
System.out.println(classifications.topK(5));
System.out.println("----------------------------");
} catch (Exception e) {
e.printStackTrace();
}
}
}
第三步:运行与验证
- 放置图像: 在项目根目录(与 pom.xml 同级或在IDE中设置的运行路径)放置一张图片,并命名为 test_image.jpg。
- 构建与运行: 如果使用Maven,可以在命令行执行:
# 编译项目
mvn clean compile
# 运行主类
mvn exec:java -Dexec.mainClass="QuickStart"
预期输出:
程序将首先下载 ResNet50 模型文件(如果本地不存在)。然后,它会输出类似以下格式的预测结果,列出图像最可能的五个分类标签及置信度:
模型加载完成,开始预测...
--- 图像预测结果 (Top 5) ---
[Classifications item: 'tiger cat', probability: 0.887]
[Classifications item: 'tabby cat', probability: 0.054]
[Classifications item: 'Egyptian cat', probability: 0.021]
[Classifications item: 'lion', probability: 0.009]
[Classifications item: 'cheetah', probability: 0.003]
----------------------------
通过以上步骤,你已经在你的 Java 环境中成功部署并运行了第一个基于 DJL 的深度学习模型推理任务。你可以将此模型推理服务集成到任何基于 JVM 的 Web 应用中,充分利用你的云主机资源。
汤不热吧