Contents
java通过gRPC整合tensorflow serving(之三)——使用java调用tfserving的模型
SORRY 本来打算上周末发的,一直有事拖延了一下。
本篇是本系列的第三篇,承接前两篇
【java通过gRPC整合tensorflow serving——gRPC java入门例子】
【java通过gRPC整合tensorflow serving(之二)——安装tfserving并部署示例模型】
上一篇我们讲完了如何安装tensorflow serving,并讲了如何编写脚本导出一个训练好的checkpoint模型文件,最后将模型部署到了tensorflow serving。
本篇我们继续讲述如何使用java调用刚刚部署好的模型。
首先打开我们第一篇中建好的Java项目。
第一步通过tensorflow serving的协议定义proto文件生成java stub
开头先说一下 其实这一步并非必要,因为已经有别人编译好的现成的tensorflow serving client可以使用,比如maven搜索一下tensorflow-serving-client。 但是我觉得还是有必要从头开始完整的演示一下,好了继续。
在我们使用bazel编译tensorflow serving之后,在tensorflow serving的目录tensorflow_serving/apis中就有定义api协议的proto文件, 我们拷贝这个api目录到我们项目的src/proto目录(如果没有这个目录,请参考本系列第一篇)。
之后我们就可以gradle build项目来编译proto文件了,不出意外肯定会报错,报proto文件中的import找不到文件,所以这里我们需要根据项目结构去修改一下报错的proto文件,基本就是开头几行的import,比如原来的
1 | import "tensorflow/core/framework/tensor.proto" |
改成
1 | import "apis/tensor.proto"; |
解决所有的proto文件报错问题后,再build项目就会看到src/main下面生成了对应的java代码,然后我们把java代码拷贝到src/main/java下。
在java中实现输入数据的预处理(如果有必要的话)
通常情况下,我们training模型的时候,都会对数据做一些预处理,比如就图像处理来说,通常需要padding之类的。 对于我们这个例子,我们用的是mnist也同样需要对图像做一些处理,按灰度展开成784的一维数组。所以这一部分就是需要在java中实现这个预处理,我们使用opencv来实现即可,考虑到时间问题,我们这里就不按照之前python中的预处理来做了,简单的将图像缩放成2828后再灰度化展成1784(由于预处理的和train模型时的预处理不同 所以预测结果不准是肯定的,不过不影响我们演示java调用tensorflow serving的过程)
java预处理代码如下
1 2 3 4 5 6 7 8 9 10 | public static Mat preprocess(String filename) { Mat rawImage = Imgcodecs.imread(filename, Imgcodecs.IMREAD_GRAYSCALE); Mat mat1 = new Mat(), mat2 = new Mat(), mat3 = new Mat(); Imgproc.resize(rawImage, mat1, new Size(28, 28)); Imgproc.threshold(mat1, mat2, 128, 255, Imgproc.THRESH_BINARY | Imgproc.THRESH_OTSU); for (int i = 0; i < mat2.cols(); i++) { mat3.push_back(mat2.col(i)); } return mat3; } |
通过grpc客户端调用tensorflow serving
正如我们前两次讲过的 这里调用tensorflow serving同样支持同步和异步两种方式,我们演示同步调用。
基本过程就是实例化一个channel,然后实例化一个Stub,然后使用Model.ModelSpec描述要调用的模型的一系列信息 包括model name、version、signature name等,最后通过stub发起predict request。
代码如下
1 2 3 4 5 6 7 8 9 | System.loadLibrary(Core.NATIVE_LIBRARY_NAME); //载入opencv jni lib Mat mat = preprocess("/Users/andy/PycharmProjects/tensorflow-mnist/blog/own_8.png"); ManagedChannel channel = ManagedChannelBuilder.forAddress("10.199.206.42", 9000).usePlaintext(true).build(); PredictionServiceGrpc.PredictionServiceBlockingStub blockingStub = PredictionServiceGrpc.newBlockingStub(channel); Model.ModelSpec modelSpec = Model.ModelSpec.newBuilder().setName("mnist").setVersion(Int64Value.newBuilder().setValue(1l).build()).setSignatureName("mnist").build(); Predict.PredictRequest request = Predict.PredictRequest.newBuilder().setModelSpec(modelSpec).putInputs("x", buildOneDimGrayImageTensorProto(mat)).setModelSpec(modelSpec).build(); Predict.PredictResponse predictResponse = blockingStub.predict(request); Map<String, TensorProto> tensorProtoMap = predictResponse.getOutputsMap(); System.out.println(tensorProtoMap.get("y")); |
其中的ip 端口之类的要根据实际情况替换。
这里要说一下的是我们使用的是TensorFlow Serving的Predict API 这在我们上一篇导出模型的时候有一个步骤是定义signature def map,我们当时使用的代码是
1 2 3 4 5 6 7 8 | prediction_signature = tf.saved_model.signature_def_utils.build_signature_def( inputs={ 'x': tf.saved_model.utils.build_tensor_info(x), }, outputs={ 'y': tf.saved_model.utils.build_tensor_info(y), }, method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME) |
可以看到method name指定的PREDICT_METHOD_NAME,也就是说我们导出的模型是支持Predict API的。 TensorFlow Serving还支持Classifiction API,从名字上我们就可以看出,后者是专门用于分类模型的,他的输入输出相比于原始的tensorproto就更高层一些,因为分类问题基本就是label、score,而Predict API就更通用一些, 这里提一下这一点。
运行java main方法 可以看到打印出了结果。
由于时间仓促,难免又漏掉或者不对的地方,欢迎大家指出问题 共同讨论。
你好,能把例子放到github上吗?能给我你的邮箱吗?非常感谢
亲啊,不知道你有没有遇到serving做预测很慢的情况,不知道什么原因。。。
你好, 我目前需要发送数据给tfserving预测, 以前用rest接口时格式是 {“instance”:[[[0,0,0,0],[0,0,0,0]]……], instances对应是一个7*7*4的数组, 现在改用grpc后, 这个数组格式要怎么构造到tensorProto类型里去呢?
我的邮箱是caufyc@126.com, 期待你的回复, 非常感谢!
大佬可以留个联系方式~
大佬可以留个联系方式么~ 有些地方有点不太懂
可以邮箱联系哈 im@andy-cheung.me