欢迎光临
我们一直在努力

【djl】【deep java library】java深度学习教程1——快速开始

如何快速开始在Java中使用DJL(Deep Java Library)进行深度学习

作为个人站长或技术爱好者,如果你希望在你的公有云虚拟机或VPS上利用Java生态系统进行深度学习模型的部署或开发,Deep Java Library (DJL) 是一个出色的选择。DJL提供了一套统一的API,让你无需关心底层引擎(如PyTorch, TensorFlow, 或 MXNet),即可在JVM上运行各种模型。

本教程将引导你完成DJL的快速启动,并展示如何使用预训练模型进行简单的图像分类推理。

准备工作

  1. Java环境: 确保你的系统安装了JDK 8或更高版本(推荐JDK 11+)。
  2. 构建工具: 使用 Maven 或 Gradle。

第一步:配置项目依赖 (Maven)

我们将使用 Maven 来管理项目依赖。DJL 提供了 Bill of Materials (BOM) 来简化版本管理。同时,我们需要引入核心 API 和一个实际的深度学习引擎(这里我们选择 PyTorch)。

在你的 pom.xml 文件中添加以下配置:

<properties>
    <!-- 推荐使用最新稳定版本 -->
    <djl.version>0.27.0</djl.version>
</properties>

<dependencies>
    <!-- 引入DJL BOM以管理版本 -->
    <dependency>
        <groupId>ai.djl</groupId>
        <artifactId>bom</artifactId>
        <version>${djl.version}</version>
        <type>pom</type>
        <scope>import</scope>
    </dependency>

    <!-- DJL核心API -->
    <dependency>
        <groupId>ai.djl</groupId>
        <artifactId>api</artifactId>
    </dependency>

    <!-- 引入 PyTorch 引擎和相关的 native C++ 库(CPU版本) -->
    <!-- 如果你不需要特定模型库,只使用api和引擎即可 -->
    <dependency>
        <groupId>ai.djl.pytorch</groupId>
        <artifactId>pytorch-engine</artifactId>
    </dependency>
    <dependency>
        <groupId>ai.djl.pytorch</groupId>
        <artifactId>pytorch-native-cpu</artifactId>
        <classifier>linux-x86_64</classifier>
    </dependency>

    <!-- 引入常用模型的Model Zoo,用于加载预训练模型 -->
    <dependency>
        <groupId>ai.djl</groupId>
        <artifactId>model-zoo</artifactId>
    </dependency>
</dependencies>

注意: pytorch-native-cpu 需要根据你VPS/VM的操作系统和架构进行调整(例如,Windows 使用 win-x86_64,Mac 使用 osx-x86_64)。对于常见的Linux VPS,linux-x86_64 是正确的选择。

第二步:编写图像分类代码

我们将编写一个简单的 Java 类,使用 DJL 的 Criteria 加载一个预训练的 ResNet50 模型,并对本地图像执行预测。

创建一个名为 QuickStart.java 的文件:

import ai.djl.Application;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.ModelException;
import ai.djl.translate.TranslateException;

import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;

public class QuickStart {

    public static void main(String[] args) throws IOException, ModelException, TranslateException {
        // 1. 设置模型加载标准
        Criteria<Image, Classifications> criteria = Criteria.builder()
                .optApplication(Application.CV.IMAGE_CLASSIFICATION)
                // 声明输入和输出类型
                .setTypes(Image.class, Classifications.class)
                // 指定使用 ResNet50 模型
                .optFilter("artifact_id", "resnet")
                .optFilter("backbone", "resnet50")
                // 确保使用 PyTorch 引擎
                .optEngine(Engine.get =DefaultEngineName())
                .build();

        // 2. 准备输入图像
        // 确保你的项目根目录下有一张名为 "test_image.jpg" 的图片。
        // 在运行前,你可以下载一张图片并重命名为 test_image.jpg
        Path inputPath = Paths.get("test_image.jpg");
        if (!inputPath.toFile().exists()) {
            System.err.println("错误: 未找到 test_image.jpg 文件。请在当前目录下放置一张图片。");
            return;
        }
        Image image = ImageFactory.getInstance().fromFile(inputPath);

        // 3. 加载模型并执行预测
        try (ZooModel<Image, Classifications> model = criteria.loadModel();
             Predictor<Image, Classifications> predictor = model.newPredictor()) {

            System.out.println("模型加载完成,开始预测...");

            // 首次运行时,模型会从网络下载,请耐心等待。
            Classifications classifications = predictor.predict(image);

            // 4. 输出结果
            System.out.println("--- 图像预测结果 (Top 5) ---");
            System.out.println(classifications.topK(5));
            System.out.println("----------------------------");

        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}

第三步:运行与验证

  1. 放置图像: 在项目根目录(与 pom.xml 同级或在IDE中设置的运行路径)放置一张图片,并命名为 test_image.jpg
  2. 构建与运行: 如果使用Maven,可以在命令行执行:
# 编译项目
mvn clean compile

# 运行主类
mvn exec:java -Dexec.mainClass="QuickStart"

预期输出:

程序将首先下载 ResNet50 模型文件(如果本地不存在)。然后,它会输出类似以下格式的预测结果,列出图像最可能的五个分类标签及置信度:

模型加载完成,开始预测...
--- 图像预测结果 (Top 5) ---
[Classifications item: 'tiger cat', probability: 0.887]
[Classifications item: 'tabby cat', probability: 0.054]
[Classifications item: 'Egyptian cat', probability: 0.021]
[Classifications item: 'lion', probability: 0.009]
[Classifications item: 'cheetah', probability: 0.003]
----------------------------

通过以上步骤,你已经在你的 Java 环境中成功部署并运行了第一个基于 DJL 的深度学习模型推理任务。你可以将此模型推理服务集成到任何基于 JVM 的 Web 应用中,充分利用你的云主机资源。

【本站文章皆为原创,未经允许不得转载】:汤不热吧 » 【djl】【deep java library】java深度学习教程1——快速开始
分享到: 更多 (0)

评论 抢沙发

  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址