欢迎光临
我们一直在努力

java通过gRPC整合tensorflow serving(之三)——使用java调用tfserving的模型

java通过gRPC整合tensorflow serving(之三)——使用java调用tfserving的模型

SORRY 本来打算上周末发的,一直有事拖延了一下。

本篇是本系列的第三篇,承接前两篇
java通过gRPC整合tensorflow serving——gRPC java入门例子
java通过gRPC整合tensorflow serving(之二)——安装tfserving并部署示例模型

上一篇我们讲完了如何安装tensorflow serving,并讲了如何编写脚本导出一个训练好的checkpoint模型文件,最后将模型部署到了tensorflow serving。

本篇我们继续讲述如何使用java调用刚刚部署好的模型。
首先打开我们第一篇中建好的Java项目。

第一步通过tensorflow serving的协议定义proto文件生成java stub

开头先说一下 其实这一步并非必要,因为已经有别人编译好的现成的tensorflow serving client可以使用,比如maven搜索一下tensorflow-serving-client。 但是我觉得还是有必要从头开始完整的演示一下,好了继续。

在我们使用bazel编译tensorflow serving之后,在tensorflow serving的目录tensorflow_serving/apis中就有定义api协议的proto文件, 我们拷贝这个api目录到我们项目的src/proto目录(如果没有这个目录,请参考本系列第一篇)。

之后我们就可以gradle build项目来编译proto文件了,不出意外肯定会报错,报proto文件中的import找不到文件,所以这里我们需要根据项目结构去修改一下报错的proto文件,基本就是开头几行的import,比如原来的

1
import "tensorflow/core/framework/tensor.proto"

改成

1
import "apis/tensor.proto";

解决所有的proto文件报错问题后,再build项目就会看到src/main下面生成了对应的java代码,然后我们把java代码拷贝到src/main/java下。

在java中实现输入数据的预处理(如果有必要的话)

通常情况下,我们training模型的时候,都会对数据做一些预处理,比如就图像处理来说,通常需要padding之类的。 对于我们这个例子,我们用的是mnist也同样需要对图像做一些处理,按灰度展开成784的一维数组。所以这一部分就是需要在java中实现这个预处理,我们使用opencv来实现即可,考虑到时间问题,我们这里就不按照之前python中的预处理来做了,简单的将图像缩放成2828后再灰度化展成1784(由于预处理的和train模型时的预处理不同 所以预测结果不准是肯定的,不过不影响我们演示java调用tensorflow serving的过程)
java预处理代码如下

1
2
3
4
5
6
7
8
9
10
public static Mat preprocess(String filename) {
        Mat rawImage = Imgcodecs.imread(filename, Imgcodecs.IMREAD_GRAYSCALE);
        Mat mat1 = new Mat(), mat2 = new Mat(), mat3 = new Mat();
        Imgproc.resize(rawImage, mat1, new Size(28, 28));
        Imgproc.threshold(mat1, mat2, 128, 255, Imgproc.THRESH_BINARY | Imgproc.THRESH_OTSU);
        for (int i = 0; i < mat2.cols(); i++) {
            mat3.push_back(mat2.col(i));
        }
        return mat3;
    }

通过grpc客户端调用tensorflow serving

正如我们前两次讲过的 这里调用tensorflow serving同样支持同步和异步两种方式,我们演示同步调用。
基本过程就是实例化一个channel,然后实例化一个Stub,然后使用Model.ModelSpec描述要调用的模型的一系列信息 包括model name、version、signature name等,最后通过stub发起predict request。

代码如下

1
2
3
4
5
6
7
8
9
        System.loadLibrary(Core.NATIVE_LIBRARY_NAME); //载入opencv jni lib
        Mat mat = preprocess("/Users/andy/PycharmProjects/tensorflow-mnist/blog/own_8.png");
        ManagedChannel channel = ManagedChannelBuilder.forAddress("10.199.206.42", 9000).usePlaintext(true).build();
        PredictionServiceGrpc.PredictionServiceBlockingStub blockingStub = PredictionServiceGrpc.newBlockingStub(channel);
        Model.ModelSpec modelSpec = Model.ModelSpec.newBuilder().setName("mnist").setVersion(Int64Value.newBuilder().setValue(1l).build()).setSignatureName("mnist").build();
        Predict.PredictRequest request = Predict.PredictRequest.newBuilder().setModelSpec(modelSpec).putInputs("x", buildOneDimGrayImageTensorProto(mat)).setModelSpec(modelSpec).build();
        Predict.PredictResponse predictResponse = blockingStub.predict(request);
        Map<String, TensorProto> tensorProtoMap = predictResponse.getOutputsMap();
        System.out.println(tensorProtoMap.get("y"));

其中的ip 端口之类的要根据实际情况替换。
这里要说一下的是我们使用的是TensorFlow Serving的Predict API 这在我们上一篇导出模型的时候有一个步骤是定义signature def map,我们当时使用的代码是

1
2
3
4
5
6
7
8
prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(
        inputs={
            'x': tf.saved_model.utils.build_tensor_info(x),
        },
        outputs={
            'y': tf.saved_model.utils.build_tensor_info(y),
        },
        method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)

可以看到method name指定的PREDICT_METHOD_NAME,也就是说我们导出的模型是支持Predict API的。 TensorFlow Serving还支持Classifiction API,从名字上我们就可以看出,后者是专门用于分类模型的,他的输入输出相比于原始的tensorproto就更高层一些,因为分类问题基本就是label、score,而Predict API就更通用一些, 这里提一下这一点。

运行java main方法 可以看到打印出了结果。

由于时间仓促,难免又漏掉或者不对的地方,欢迎大家指出问题 共同讨论。

【本站文章皆为原创,未经允许不得转载】:汤不热吧 » java通过gRPC整合tensorflow serving(之三)——使用java调用tfserving的模型
分享到: 更多 (0)

评论 6

  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址
  1. #1

    你好,能把例子放到github上吗?能给我你的邮箱吗?非常感谢

    eclipseme7年前 (2018-03-20)回复
  2. #2

    亲啊,不知道你有没有遇到serving做预测很慢的情况,不知道什么原因。。。

    scotter7年前 (2018-03-21)回复
  3. #3

    你好, 我目前需要发送数据给tfserving预测, 以前用rest接口时格式是 {“instance”:[[[0,0,0,0],[0,0,0,0]]……], instances对应是一个7*7*4的数组, 现在改用grpc后, 这个数组格式要怎么构造到tensorProto类型里去呢?
    我的邮箱是caufyc@126.com, 期待你的回复, 非常感谢!

    forest1235年前 (2019-09-06)回复
  4. #4

    大佬可以留个联系方式~

    w5年前 (2020-02-06)回复
  5. #5

    大佬可以留个联系方式么~ 有些地方有点不太懂

    w5年前 (2020-02-06)回复