88

Java 和 Python 使用 gRPC 访问 TensorFlow Serving 的代码

 4 years ago
source link: https://www.tuicool.com/articles/vIBJzub
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

发现网上大量的代码都是 MNIST 的,我自己反正不是搞图像处理的,所以这个例子我怎么都不想搞;

像 Wide & Deep 这种包含各种特征的模型才是我需要的;本文用的 Iris 模型也是从文本数据训练的,所以非常简单;

本文给出用 Python 和 Java 访问 TensorFlow Serving 的代码。

Java 版本:使用 gRPC 访问 TensorFlow Serving 的代码

package io.github.qf6101.tensorflowserving;
 
import com.google.protobuf.ByteString;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.netty.NegotiationType;
import io.grpc.netty.NettyChannelBuilder;
import org.tensorflow.example.*;
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.TensorProto;
import org.tensorflow.framework.TensorShapeProto;
import tensorflow.serving.Model;
import tensorflow.serving.Predict;
import tensorflow.serving.PredictionServiceGrpc;
 
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
 
/**
 * Minimal TensorFlow Serving gRPC client: builds one Iris sample as a serialized
 * tf.train.Example, sends it to the "iris" model's Predict endpoint and prints the response.
 *
 * References:
 *   https://www.jianshu.com/p/d82107165119
 *   https://github.com/grpc/grpc-java
 */
public class PssIrisGrpcClient {

    /**
     * Builds a tf.train.Example holding one hard-coded Iris sample
     * (SepalLength=5.1, SepalWidth=3.3, PetalLength=1.7, PetalWidth=0.5).
     *
     * @return the assembled Example
     */
    public static Example createExample() {
        Map<String, Float> dataMap = new HashMap<String, Float>();
        dataMap.put("SepalLength", 5.1f);
        dataMap.put("SepalWidth", 3.3f);
        dataMap.put("PetalLength", 1.7f);
        dataMap.put("PetalWidth", 0.5f);
        return createExample(dataMap);
    }

    /**
     * Builds a tf.train.Example whose features are the given name/value pairs,
     * each stored as a single-element FloatList.
     *
     * @param dataMap feature name to feature value
     * @return the assembled Example
     */
    public static Example createExample(Map<String, Float> dataMap) {
        Features features = Features.newBuilder()
                .putAllFeature(mapToFeatureMap(dataMap))
                .build();
        return Example.newBuilder().setFeatures(features).build();
    }

    /**
     * Converts every entry of dataMap into a Feature wrapping a one-element FloatList.
     *
     * @param dataMap feature name to feature value
     * @return feature name to Feature
     */
    private static Map<String, Feature> mapToFeatureMap(Map<String, Float> dataMap) {
        Map<String, Feature> resultMap = new HashMap<String, Feature>();
        // Iterate entries directly rather than keySet() + get(), which looked each key up twice.
        for (Map.Entry<String, Float> entry : dataMap.entrySet()) {
            FloatList floatList = FloatList.newBuilder().addValue(entry.getValue()).build();
            Feature feature = Feature.newBuilder().setFloatList(floatList).build();
            resultMap.put(entry.getKey(), feature);
        }
        return resultMap;
    }

    /**
     * Sends one Predict request (model "iris", version 1, signature "classification",
     * input key "inputs") to 127.0.0.1:8888 over plaintext gRPC and prints the response.
     */
    public static void main(String[] args) {
        String host = "127.0.0.1";
        int port = 8888;

        ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port)
                // Channels are secure by default (via SSL/TLS). For the example we disable TLS to avoid
                // needing certificates.
                .usePlaintext()
                .build();
        try {
            // Bound the blocking call so it cannot hang forever if the server stops responding.
            PredictionServiceGrpc.PredictionServiceBlockingStub blockingStub = PredictionServiceGrpc
                    .newBlockingStub(channel)
                    .withDeadlineAfter(10, java.util.concurrent.TimeUnit.SECONDS);

            com.google.protobuf.Int64Value version = com.google.protobuf.Int64Value.newBuilder()
                    .setValue(1)
                    .build();

            Model.ModelSpec modelSpec = Model.ModelSpec.newBuilder()
                    .setName("iris")
                    .setVersion(version)
                    .setSignatureName("classification")
                    .build();

            // The model input is a 1-D DT_STRING tensor of serialized Example protos.
            List<ByteString> exampleList = new ArrayList<ByteString>();
            exampleList.add(createExample().toByteString());

            TensorShapeProto.Dim featureDim = TensorShapeProto.Dim.newBuilder().setSize(exampleList.size()).build();
            TensorShapeProto shapeProto = TensorShapeProto.newBuilder().addDim(featureDim).build();
            TensorProto tensorProto = TensorProto.newBuilder()
                    .addAllStringVal(exampleList)
                    .setDtype(DataType.DT_STRING)
                    .setTensorShape(shapeProto)
                    .build();

            Predict.PredictRequest request = Predict.PredictRequest.newBuilder()
                    .setModelSpec(modelSpec)
                    .putInputs("inputs", tensorProto)
                    .build();
            Predict.PredictResponse response = blockingStub.predict(request);
            System.out.println(response);
        } finally {
            // Release the channel even when predict() throws (e.g. StatusRuntimeException).
            shutdownChannel(channel);
        }
    }

    /**
     * Shuts the channel down gracefully, forcing termination if it does not finish within 5 seconds.
     *
     * @param channel the channel to close
     */
    private static void shutdownChannel(ManagedChannel channel) {
        channel.shutdown();
        try {
            if (!channel.awaitTermination(5, java.util.concurrent.TimeUnit.SECONDS)) {
                channel.shutdownNow();
            }
        } catch (InterruptedException e) {
            channel.shutdownNow();
            // Preserve the interrupt status for any caller further up the stack.
            Thread.currentThread().interrupt();
        }
    }
}

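需要增加如下 Maven 依赖(注意:`tensorflow.serving.*` 下的 `Model`、`Predict`、`PredictionServiceGrpc` 等类并不包含在这些依赖中,通常需要用 TensorFlow Serving 的 .proto 文件自行生成):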

        <!-- https://mvnrepository.com/artifact/org.tensorflow/tensorflow -->
        <dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>tensorflow</artifactId>
            <version>1.12.0</version>
        </dependency>
 
        <!-- https://mvnrepository.com/artifact/io.grpc/grpc-netty -->
        <dependency>
            <groupId>io.grpc</groupId>
            <artifactId>grpc-netty</artifactId>
            <version>1.20.0</version>
        </dependency>
 
        <!-- https://mvnrepository.com/artifact/io.grpc/grpc-protobuf -->
        <dependency>
            <groupId>io.grpc</groupId>
            <artifactId>grpc-protobuf</artifactId>
            <version>1.20.0</version>
        </dependency>
        <!-- https://mvnrepository.com/artifact/io.grpc/grpc-stub -->
        <dependency>
            <groupId>io.grpc</groupId>
            <artifactId>grpc-stub</artifactId>
            <version>1.20.0</version>
        </dependency>

输出结果:

outputs {
  key: "scores"
  value {
    dtype: DT_FLOAT
    tensor_shape {
      dim {
        size: 1
      }
      dim {
        size: 3
      }
    }
    float_val: 0.9997806
    float_val: 2.1938368E-4
    float_val: 1.382611E-9
  }
}
outputs {
  key: "classes"
  value {
    dtype: DT_STRING
    tensor_shape {
      dim {
        size: 1
      }
      dim {
        size: 3
      }
    }
    string_val: "0"
    string_val: "1"
    string_val: "2"
  }
}

Python 版本:使用 gRPC 访问 TensorFlow Serving 的代码

# 创建 gRPC 连接
import grpc
from grpc.beta import implementations
import pandas as pd
import tensorflow as tf
from tensorflow_serving.apis import (
    classification_pb2,
    prediction_service_pb2,
    prediction_service_pb2_grpc,
)
 
# Open a plaintext (no TLS) gRPC channel to TensorFlow Serving on port 8888
# (tensorflow_model_server's default gRPC port is 8500) and create a
# PredictionService stub on it. This uses the stable grpc API; the old
# grpc.beta `implementations` + `beta_create_PredictionService_stub` pair is
# deprecated and absent from recent tensorflow-serving-api releases.
channel = grpc.insecure_channel('127.0.0.1:8888')
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
 
def _create_feature(v):
    """Wrap the scalar *v* in a ``tf.train.Feature`` holding a one-element float list."""
    float_list = tf.train.FloatList(value=[v])
    return tf.train.Feature(float_list=float_list)
 
# Raw Iris feature rows to classify; each one becomes a tf.train.Example.
samples = [
    {"SepalLength": 5.1, "SepalWidth": 3.3, "PetalLength": 1.7, "PetalWidth": 0.5},
    {"SepalLength": 1.1, "SepalWidth": 1.3, "PetalLength": 1.7, "PetalWidth": 0.5},
]
examples = [
    tf.train.Example(
        features=tf.train.Features(
            feature={name: _create_feature(value) for name, value in sample.items()}
        )
    )
    for sample in samples
]

# Build the Classify request. Only the model name is set (no signature_name);
# the sample response printed in this article reports "serving_default".
request = classification_pb2.ClassificationRequest()
request.model_spec.name = 'iris'
request.input.example_list.examples.extend(examples)

# Send the request with a 10-second timeout and print the result.
response = stub.Classify(request, timeout=10.0)
print(response)

Python 代码看起来简单不少,但是我们的线上服务都是 Java,不好集成,所以只能用来做一些离线的批量预测;

输出如下:

result {
  classifications {
    classes {
      label: "0"
      score: 0.9997805953025818
    }
    classes {
      label: "1"
      score: 0.00021938368445262313
    }
    classes {
      label: "2"
      score: 1.382611025668723e-09
    }
  }
  classifications {
    classes {
      label: "0"
      score: 0.0736534595489502
    }
    classes {
      label: "1"
      score: 0.8393719792366028
    }
    classes {
      label: "2"
      score: 0.08697459846735
    }
  }
}
model_spec {
  name: "iris"
  version {
    value: 1
  }
  signature_name: "serving_default"
}

个人其实非常喜欢 HTTP + JSON 接口,完全不用折腾 gRPC 这么多麻烦的东西,尤其是 Java 的 gRPC,遇到了好多问题,好崩溃;

不过据说 gRPC 比 HTTP 性能好不少,所以线上只能用 gRPC。


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK