TensorFlow機器學習系統(三):以Java載入TensorFlow模型進行跨平台分類與預測作業

於以TensorFlow Python API完成之深度學習模型,透過跨平台語言Java封裝TensorFlow並讀取模型生成分類與預測結果,降低於目標平台建置系統環境之複雜度,以利於快速部屬系統。
系統環境:
  • TensorFlow 1.5 Java  API
  • Java 8
  • 皆安裝於預設安裝目錄
建立步驟:
  1. 配置Maven設定檔
    • 開發Java專案建議採用Maven,於設定檔pom.xml內加入下列依賴庫。
    •     <dependency>
              <groupid>org.tensorflow</groupid>
              <artifactid>tensorflow</artifactid>
              <version>1.5.0</version>
          </dependency>
      
          <dependency>
              <groupid>org.tensorflow</groupid>
              <artifactid>proto</artifactid>
              <version>1.5.0</version>
          </dependency>
      
          <dependency>
              <groupid>com.google.protobuf</groupid>
              <artifactid>protobuf-java</artifactid>
              <version>3.5.1</version>
          </dependency>
      
  2. 引入相關資源
    • 載入TensorFlow讀取模型所需的庫。
    • import org.tensorflow.*;
      import org.tensorflow.example.*;
      import com.google.protobuf.ByteString;
      
  3. 定義特徵資料讀取方式
    • 將要預測資料先存入 ArrayList<float> input_data 動態陣列內,透過Example.newBuilder().setFeatures函式轉為模型所需的特徵輸入資料。
    • float[] input = new float[input_data.size()];
      int i = 0;
      for (Float f : input_data) {
          input[i++] = (f != null ? f : Float.NaN);
      }
      Features features = Features.newBuilder()
          .putFeature("x", feature(input))
          .build();
      Example example = Example.newBuilder().setFeatures(features).build();
      
  4. 讀取深度學習模型並生成結果
    • 根據所訓練出來的模型輸出的分類或預測值數量,定義DEFAULT_LABELS_NUM內的數值。
    • 設定模型存放位置(見系列文章二),此處預設放在model內。
    • 由於先前所訓練的模型的模型為DNN迴歸模型,故在fetch參數內定義輸出為"dnn/logits/BiasAdd:0",並用resultArray陣列讀出結果。
    • final static int DEFAULT_LABELS_NUM = 1;
      try (SavedModelBundle model = SavedModelBundle.load("./model", "serve")) {
          Session session = model.session();
          final String xName = "input_example_tensor:0";
          final String scoresName = "dnn/logits/BiasAdd:0"; // TF1.5版
          // final String scoresName = "dnn/head/logits:0"; // TF1.4版
          float[][] resultArray;
          // 輸入模型生成結果(DNNRegressor)
          try (Tensor inputBatch = Tensors.create(new byte[][] { example.toByteArray() });
              Tensor output = session
             .runner()
             .feed(xName, inputBatch)
             .fetch(scoresName)
             .run()
             .get(0)
             .expect(Float.class)) {
                 resultArray = output.copyTo(new float[1][DEFAULT_LABELS_NUM]);
          }
          return resultArray[0];
      }
      
  5. 參考資料:
  6. 系列文章:

留言

這個網誌中的熱門文章