TensorFlow機器學習系統(三):以Java載入TensorFlow模型進行跨平台分類與預測作業
於以TensorFlow Python API完成之深度學習模型,透過跨平台語言Java封裝TensorFlow並讀取模型生成分類與預測結果,降低於目標平台建置系統環境之複雜度,以利於快速部屬系統。
系統環境:
系統環境:
- TensorFlow 1.5 Java API
- Java 8
- 皆安裝於預設安裝目錄
- 配置Maven設定檔
- 開發Java專案建議採用Maven,於設定檔pom.xml內加入下列依賴庫。
- 引入相關資源
- 載入TensorFlow讀取模型所需的庫。
- 定義特徵資料讀取方式
- 將要預測資料先存入
ArrayList<float> input_data 動態陣列內,透過Example.newBuilder().setFeatures函式轉為模型所需的特徵輸入資料。 - 讀取深度學習模型並生成結果
- 根據所訓練出來的模型輸出的分類或預測值數量,定義DEFAULT_LABELS_NUM內的數值。
- 設定模型存放位置(見系列文章二),此處預設放在model內。
- 由於先前所訓練的模型的模型為DNN迴歸模型,故在fetch參數內定義輸出為"dnn/logits/BiasAdd:0",並用resultArray陣列讀出結果。
- 參考資料:
- 系列文章:
<dependency>
<groupid>org.tensorflow</groupid>
<artifactid>tensorflow</artifactid>
<version>1.5.0</version>
</dependency>
<dependency>
<groupid>org.tensorflow</groupid>
<artifactid>proto</artifactid>
<version>1.5.0</version>
</dependency>
<dependency>
<groupid>com.google.protobuf</groupid>
<artifactid>protobuf-java</artifactid>
<version>3.5.1</version>
</dependency>
import org.tensorflow.*;
import org.tensorflow.example.*;
import com.google.protobuf.ByteString;
float[] input = new float[input_data.size()];
int i = 0;
for (Float f : input_data) {
input[i++] = (f != null ? f : Float.NaN);
}
Features features = Features.newBuilder()
.putFeature("x", feature(input))
.build();
Example example = Example.newBuilder().setFeatures(features).build();
final static int DEFAULT_LABELS_NUM = 1;
try (SavedModelBundle model = SavedModelBundle.load("./model", "serve")) {
Session session = model.session();
final String xName = "input_example_tensor:0";
final String scoresName = "dnn/logits/BiasAdd:0"; // TF1.5版
// final String scoresName = "dnn/head/logits:0"; // TF1.4版
float[][] resultArray;
// 輸入模型生成結果(DNNRegressor)
try (Tensor inputBatch = Tensors.create(new byte[][] { example.toByteArray() });
Tensor output = session
.runner()
.feed(xName, inputBatch)
.fetch(scoresName)
.run()
.get(0)
.expect(Float.class)) {
resultArray = output.copyTo(new float[1][DEFAULT_LABELS_NUM]);
}
return resultArray[0];
}
留言
張貼留言