コンテンツにスキップ

モバイルでの画像解像度の向上

モバイルでの機械学習超解像による画像解像度の向上

Section titled “モバイルでの機械学習超解像による画像解像度の向上”

前処理と後処理を含むモデルを使用して、ONNX Runtime Mobileを使用して画像解像度を向上させるアプリケーションを構築する方法を学びます。

このチュートリアルを使用して、AndroidまたはiOS用のアプリケーションを構築できます。

アプリケーションは画像入力を取得し、ボタンがクリックされると超解像操作を実行し、以下のスクリーンショットのように、解像度が向上した画像を下に表示します。

猫の超解像

このチュートリアルで使用される機械学習モデルは、このページの下部で参照されているPyTorchチュートリアルで使用されているモデルに基づいています。

PyTorchモデルをONNX形式にエクスポートし、前処理と後処理を追加する便利なPythonスクリプトを提供します。

  1. このスクリプトを実行する前に、以下のpythonパッケージをインストールします:

    Terminal window
    pip install torch
    pip install pillow
    pip install onnx
    pip install onnxruntime
    pip install --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ onnxruntime-extensions

    バージョンに関する注意:最良の超解像結果は、ONNX opset 18(アンチエイリアシングを備えたResize演算子のサポート)で達成されます。これはonnx 1.13.0とonnxruntime 1.14.0以降でサポートされています。onnxruntime-extensionsパッケージはプレリリースバージョンです。リリースバージョンはまもなく利用可能になります。

  2. 次に、onnxruntime-extensions GitHubリポジトリからスクリプトとテスト画像をダウンロードします(このリポジトリをまだクローンしていない場合):

    Terminal window
    curl https://raw.githubusercontent.com/microsoft/onnxruntime-extensions/main/tutorials/superresolution_e2e.py > superresolution_e2e.py
    curl https://raw.githubusercontent.com/microsoft/onnxruntime-extensions/main/tutorials/data/super_res_input.png > data/super_res_input.png
  3. スクリプトを実行してコアモデルをエクスポートし、前処理と後処理を追加します

    Terminal window
    python superresolution_e2e.py

スクリプトが実行された後、スクリプトを実行した場所のフォルダに2つのONNXファイルが表示されるはずです:

Terminal window
pytorch_superresolution.onnx
pytorch_superresolution_with_pre_and_post_processing.onnx

2つのモデルをnetronに読み込むと、2つの間の入力と出力の違いがわかります。下の最初の2つの画像は、入力がチャネルデータのバッチである元のモデルを示しており、次の2つは入力と出力が画像バイトであることを示しています。

前処理と後処理なしのONNXモデル

前処理と後処理なしのONNXモデルの入力と出力

前処理と後処理ありのONNXモデル

前処理と後処理ありのONNXモデルの入力と出力

アプリケーションコードを作成する時が来ました。

  • Android Studio Dolphin 2021.3.1 Patch +(Mac/Windows/Linuxにインストール)
  • Android SDK 29+
  • Android NDK r22+
  • AndroidデバイスまたはAndroidエミュレーター

GitHubでAndroid超解像アプリの完全なソースコードを見つけることができます。

ソースコードからアプリを実行するには、上記のリポジトリをクローンし、build.gradleファイルをAndroid studioに読み込み、ビルドして実行します!

アプリを段階的にビルドするには、以下のセクションに従ってください。

Android studioでPhone and Tablet用の新しいプロジェクトを作成し、空白テンプレートを選択します。アプリケーションをsuper_resolutionまたは類似の名前にします。

アプリのbuild.gradleに次の依存関係を追加します:

implementation 'com.microsoft.onnxruntime:onnxruntime-android:latest.release'
implementation 'com.microsoft.onnxruntime:onnxruntime-extensions-android:latest.release'
  1. モデルファイルを生リソースとして追加する

    src/main/resフォルダーにrawというフォルダーを作成し、ONNXモデルをrawフォルダーに移動またはコピーします。

  2. テスト画像をアセットとして追加する

    メインプロジェクトフォルダーにassetsというフォルダーを作成し、超解像を実行する画像をtest_superresolution.pngというファイル名でそのフォルダーにコピーします

メインアプリケーションクラスコード

Section titled “メインアプリケーションクラスコード”

MainActivity.ktというファイルを作成し、以下のコードを追加します。

  1. インポートステートメントを追加する

    import ai.onnxruntime.*
    import ai.onnxruntime.extensions.OrtxPackage
    import android.annotation.SuppressLint
    import android.os.Bundle
    import android.widget.Button
    import android.widget.ImageView
    import android.widget.Toast
    import androidx.activity.*
    import androidx.appcompat.app.AppCompatActivity
    import kotlinx.android.synthetic.main.activity_main.*
    import kotlinx.coroutines.*
    import java.io.InputStream
    import java.util.*
    import java.util.concurrent.ExecutorService
    import java.util.concurrent.Executors
  2. メインアクティビティクラスを作成し、クラス変数を追加する

    class MainActivity : AppCompatActivity() {
    private var ortEnv: OrtEnvironment = OrtEnvironment.getEnvironment()
    private lateinit var ortSession: OrtSession
    private var inputImage: ImageView? = null
    private var outputImage: ImageView? = null
    private var superResolutionButton: Button? = null
    ...
    }
  3. onCreate()メソッドを追加する

    ここでONNX Runtimeセッションを初期化します。セッションは、アプリケーションで推論を実行するために使用されるモデルへの参照を保持します。また、セッションオプションパラメータを受け取り、ここで異なる実行プロバイダー(NNAPIなどのハードウェアアクセラレータ)を指定できます。この場合、デフォルトでCPUで実行します。ただし、モデルの入力と出力にある画像エンコーディングおよびデコーディング演算子が見つかるカスタム演算子ライブラリを登録します。

    override fun onCreate(savedInstanceState: Bundle?) {
    super.onCreate(savedInstanceState)
    setContentView(R.layout.activity_main)
    inputImage = findViewById(R.id.imageView1)
    outputImage = findViewById(R.id.imageView2);
    superResolutionButton = findViewById(R.id.super_resolution_button)
    inputImage?.setImageBitmap(
    BitmapFactory.decodeStream(readInputImage())
    );
    // Ortセッションを初期化し、カスタム演算子を含むonnxruntime拡張パッケージを登録します。
    // 注:これらは入力画像を元のモデルが要求する形式にデコードし、
    // モデル出力をpng形式にエンコードするために使用されます
    val sessionOptions: OrtSession.SessionOptions = OrtSession.SessionOptions()
    sessionOptions.registerCustomOpLibrary(OrtxPackage.getLibraryPath())
    ortSession = ortEnv.createSession(readModel(), sessionOptions)
    superResolutionButton?.setOnClickListener {
    try {
    performSuperResolution(ortSession)
    Toast.makeText(baseContext, "Super resolution performed!", Toast.LENGTH_SHORT)
    .show()
    } catch (e: Exception) {
    Log.e(TAG, "Exception caught when perform super resolution", e)
    Toast.makeText(baseContext, "Failed to perform super resolution", Toast.LENGTH_SHORT)
    .show()
    }
    }
    }
  4. onDestroyメソッドを追加する

    override fun onDestroy() {
    super.onDestroy()
    ortEnv.close()
    ortSession.close()
    }
  5. updateUIメソッドを追加する

    private fun updateUI(result: Result) {
    outputImage?.setImageBitmap(result.outputBitmap)
    }
  6. readModelメソッドを追加する

    このメソッドは、リソースフォルダーからONNXモデルを読み取ります。

    private fun readModel(): ByteArray {
    val modelID = R.pytorch_superresolution_with_pre_post_processing_op18
    return resources.openRawResource(modelID).readBytes()
    }
  7. 入力画像を読み取るメソッドを追加する

    このメソッドは、アセットフォルダーからテスト画像を読み取ります。現在、アプリケーションに組み込まれている固定画像を読み取ります。サンプルはまもなく拡張され、カメラまたはカメラロールから直接画像を読み取るようになります。

    private fun readInputImage(): InputStream {
    return assets.open("test_superresolution.png")
    }
  8. 推論を実行するメソッドを追加する

    このメソッドは、アプリケーションの中核であるメソッドSuperResPerformer.upscale()を呼び出します。これは、モデルで推論を実行するメソッドです。このコードは次のセクションに示されています。

    private fun performSuperResolution(ortSession: OrtSession) {
    var superResPerformer = SuperResPerformer()
    var result = superResPerformer.upscale(readInputImage(), ortEnv, ortSession)
    updateUI(result);
    }
  9. TAGオブジェクトを追加する

    companion object {
    const val TAG = "ORTSuperResolution"
    }

SuperResPerformer.ktというファイルを作成し、以下のコードスニペットを追加します。

  1. インポートを追加する

    import ai.onnxruntime.OnnxJavaType
    import ai.onnxruntime.OrtSession
    import ai.onnxruntime.OnnxTensor
    import ai.onnxruntime.OrtEnvironment
    import android.graphics.Bitmap
    import android.graphics.BitmapFactory
    import java.io.InputStream
    import java.nio.ByteBuffer
    import java.util.*
  2. 結果クラスを作成する

    internal data class Result(
    var outputBitmap: Bitmap? = null
    ) {}
  3. 超解像パフォーマークラスを作成する

    このクラスとそのメイン関数upscaleは、ONNX Runtimeへの呼び出しの大部分が存在する場所です。

    • OrtEnvironmentシングルトンは、環境のプロパティと設定されたログレベルを維持します
    • OnnxTensor.createTensor()は、入力画像バイトで構成されるテンソルを作成するために使用され、モデルへの入力として適切です
    • OnnxJavaType.UINT8は、入力テンソルのByteBufferのデータ型です
    • OrtSession.run()は、モデルで推論(予測)を実行して、出力のアップスケールされた画像を取得します
    internal class SuperResPerformer(
    ) {
    fun upscale(inputStream: InputStream, ortEnv: OrtEnvironment, ortSession: OrtSession): Result {
    var result = Result()
    // ステップ1: 画像をバイト配列(生画像バイト)に変換
    val rawImageBytes = inputStream.readBytes()
    // ステップ2: バイト配列の形状を取得し、ortテンソルを作成
    val shape = longArrayOf(rawImageBytes.size.toLong())
    val inputTensor = OnnxTensor.createTensor(
    ortEnv,
    ByteBuffer.wrap(rawImageBytes),
    shape,
    OnnxJavaType.UINT8
    )
    inputTensor.use {
    // ステップ3: ort推論セッション実行を呼び出す
    val output = ortSession.run(Collections.singletonMap("image", inputTensor))
    // ステップ4: 出力分析
    output.use {
    val rawOutput = (output?.get(0)?.value) as ByteArray
    val outputImageBitmap =
    byteArrayToBitmap(rawOutput)
    // ステップ5: 出力結果を設定
    result.outputBitmap = outputImageBitmap
    }
    }
    return result
    }

Android studio内で:

  • Build -> Make Projectを選択
  • Run -> app

アプリはデバイスエミュレーターで実行されます。Androidデバイスに接続して、デバイスでアプリを実行します。

  • Xcode 13.0以上をインストール(できれば最新バージョン)
  • iOSデバイスまたはiOSシミュレーター
  • Xcodeコマンドラインツールxcode-select --install
  • CocoaPods sudo gem install cocoapods
  • 有効なApple Developer ID(デバイスで実行する予定の場合)

GitHubでiOS超解像アプリの完全なソースコードを見つけることができます。

ソースコードからアプリを実行するには:

  1. onnxruntime-inference-examplesリポジトリをクローンする

    Terminal window
    git clone https://github.com/microsoft/onnxruntime-inference-examples
    cd onnxruntime-inference-examples/mobile/examples/super_resolution/ios
  2. 必要なpodファイルをインストールする

    Terminal window
    pod install
  3. XCodeで生成されたORTSuperResolution.xcworkspaceファイルを開く

    (オプション:デバイスで実行する場合にのみ必要)開発チームを選択

  4. アプリケーションを実行する

    iOSデバイスまたはシミュレーターを接続し、アプリをビルドして実行します

    Perform Super Resolutionボタンをクリックして、アプリの動作を確認します

アプリを段階的に開発するには、以下のセクションに従ってください。

APPテンプレートを使用してXCodeで新しいプロジェクトを作成します

以下のpodsをインストールします:

Terminal window
# OrtSuperResolution用のPods
pod 'onnxruntime-c'
# プレリリースバージョンのpods
pod 'onnxruntime-extensions-c', '0.5.0-dev+261962.e3663fb'
  1. モデルファイルをプロジェクトに追加する

    このチュートリアルの最初に生成されたモデルファイルをプロジェクトフォルダーのルートにコピーします。

  2. テスト画像をアセットとして追加する

    超解像を実行する画像をプロジェクトフォルダーのルートにコピーします。

ORTSuperResolutionApp.swiftというファイルを開き、以下のコードを追加します:

import SwiftUI
@main
struct ORTSuperResolutionApp: App {
var body: some Scene {
WindowGroup {
ContentView()
}
}
}

ContentView.swiftというファイルを開き、以下のコードを追加します:

import SwiftUI
struct ContentView: View {
@State private var performSuperRes = false
func runOrtSuperResolution() -> UIImage? {
do {
let outputImage = try ORTSuperResolutionPerformer.performSuperResolution()
return outputImage
} catch let error as NSError {
print("Error: \(error.localizedDescription)")
return nil
}
}
var body: some View {
ScrollView {
VStack {
VStack {
Text("ORTSuperResolution").font(.title).bold()
.frame(width: 400, height: 80)
.border(Color.purple, width: 4)
.background(Color.purple)
Text("Input low resolution image: ").frame(width: 350, height: 40, alignment:.leading)
Image("cat_224x224").frame(width: 250, height: 250)
Button("Perform Super Resolution") {
performSuperRes.toggle()
}
if performSuperRes {
Text("Output high resolution image: ").frame(width: 350, height: 40, alignment:.leading)
if let outputImage = runOrtSuperResolution() {
Image(uiImage: outputImage)
} else {
Text("Unable to perform super resolution. ").frame(width: 350, height: 40, alignment:.leading)
}
}
Spacer()
}
}
.padding()
}
}
}
struct ContentView_Previews: PreviewProvider {
static var previews: some View {
ContentView()
}
}

Swift / Objective Cブリッジヘッダー

Section titled “Swift / Objective Cブリッジヘッダー”

ORTSuperResolution-Bridging-Header.hというファイルを作成し、以下のインポートステートメントを追加します:

#import "ORTSuperResolutionPerformer.h"
  1. ORTSuperResolutionPerformer.hというファイルを作成し、以下のコードを追加します:

    #ifndef ORTSuperResolutionPerformer_h
    #define ORTSuperResolutionPerformer_h
    #import <Foundation/Foundation.h>
    #import <UIKit/UIKit.h>
    NS_ASSUME_NONNULL_BEGIN
    @interface ORTSuperResolutionPerformer : NSObject
    + (nullable UIImage*)performSuperResolutionWithError:(NSError**)error;
    @end
    NS_ASSUME_NONNULL_END
    #endif
  2. ORTSuperResolutionPerformer.mmというファイルを作成し、以下のコードを追加します:

    #import "ORTSuperResolutionPerformer.h"
    #import <Foundation/Foundation.h>
    #import <UIKit/UIKit.h>
    #include <array>
    #include <cstdint>
    #include <stdexcept>
    #include <string>
    #include <vector>
    #include <onnxruntime_cxx_api.h>
    #include <onnxruntime_extensions.h>
    @implementation ORTSuperResolutionPerformer
    + (nullable UIImage*)performSuperResolutionWithError:(NSError **)error {
    UIImage* output_image = nil;
    try {
    // カスタム演算子を登録
    const auto ort_log_level = ORT_LOGGING_LEVEL_INFO;
    auto ort_env = Ort::Env(ort_log_level, "ORTSuperResolution");
    auto session_options = Ort::SessionOptions();
    if (RegisterCustomOps(session_options, OrtGetApiBase()) != nullptr) {
    throw std::runtime_error("RegisterCustomOps failed");
    }
    // ステップ1: モデルを読み込む
    NSString *model_path = [NSBundle.mainBundle pathForResource:@"pt_super_resolution_with_pre_post_processing_opset16"
    ofType:@"onnx"];
    if (model_path == nullptr) {
    throw std::runtime_error("Failed to get model path");
    }
    // ステップ2: Ort推論セッションを作成
    auto sess = Ort::Session(ort_env, [model_path UTF8String], session_options);
    // 入力画像を読み込む
    // 注:PNGファイルを混乱させないようにXcode設定を設定する必要があります:
    // "Build Settings"で:
    // - "Compress PNG Files"を"No"に設定
    // - "Remove Text Metadata From PNG Files"を"No"に設定
    NSString *input_image_path =
    [NSBundle.mainBundle pathForResource:@"cat_224x224" ofType:@"png"];
    if (input_image_path == nullptr) {
    throw std::runtime_error("Failed to get image path");
    }
    // ステップ3: 入力テンソルと入出力名を準備
    NSMutableData *input_data =
    [NSMutableData dataWithContentsOfFile:input_image_path];
    const int64_t input_data_length = input_data.length;
    const auto memoryInfo =
    Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
    const auto input_tensor = Ort::Value::CreateTensor(memoryInfo, [input_data mutableBytes], input_data_length,
    &input_data_length, 1, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
    constexpr auto input_names = std::array{"image"};
    constexpr auto output_names = std::array{"image_out"};
    // ステップ4: 推論セッション実行を呼び出す
    const auto outputs = sess.Run(Ort::RunOptions(), input_names.data(),
    &input_tensor, 1, output_names.data(), 1);
    if (outputs.size() != 1) {
    throw std::runtime_error("Unexpected number of outputs");
    }
    // ステップ5: モデル出力を分析
    const auto &output_tensor = outputs.front();
    const auto output_type_and_shape_info = output_tensor.GetTensorTypeAndShapeInfo();
    const auto output_shape = output_type_and_shape_info.GetShape();
    if (const auto output_element_type =
    output_type_and_shape_info.GetElementType();
    output_element_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8) {
    throw std::runtime_error("Unexpected output element type");
    }
    const uint8_t *output_data_raw = output_tensor.GetTensorData<uint8_t>();
    // ステップ6: 生バイトをNSDataに変換し、表示可能なUIImageとして返す
    NSData *output_data = [NSData dataWithBytes:output_data_raw length:(output_shape[0])];
    output_image = [UIImage imageWithData:output_data];
    } catch (std::exception &e) {
    NSLog(@"%s error: %s", __FUNCTION__, e.what());
    static NSString *const kErrorDomain = @"ORTSuperResolution";
    constexpr NSInteger kErrorCode = 0;
    if (error) {
    NSString *description =
    [NSString stringWithCString:e.what() encoding:NSASCIIStringEncoding];
    *error =
    [NSError errorWithDomain:kErrorDomain
    code:kErrorCode
    userInfo:@{NSLocalizedDescriptionKey : description}];
    }
    return nullptr;
    }
    if (error) {
    *error = nullptr;
    }
    return output_image;
    }
    @end

XCodeで、三角形のビルドアイコンを選択してアプリをビルドして実行します!

元のPyTorchチュートリアル