コンテンツにスキップ

カスタム演算子

カスタム演算子チュートリアル

Section titled “カスタム演算子チュートリアル”

このチュートリアルでは、ONNX Runtimeでカスタム演算子を作成して使用する方法を示します。

  • ONNX RuntimeがC++でビルドされていること
  • C++コンパイラがインストールされていること

カスタム演算子はC++で実装されます。次のコードは、2つのテンソルを追加する単純なカスタム演算子の例です。

#include <onnxruntime_cxx_api.h>
struct CustomOp {
CustomOp(const OrtApi& api) : api_(api) {}
void Compute(OrtKernelContext* context) {
// 入力を取得
const OrtValue* input_x = api_.KernelContext_GetInput(context, 0);
const OrtValue* input_y = api_.KernelContext_GetInput(context, 1);
const float* x = api_.GetTensorData<float>(input_x);
const float* y = api_.GetTensorData<float>(input_y);
// 出力形状を取得
OrtTensorDimensions dimensions(api_, input_x);
// 出力を作成
OrtValue* output = api_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
float* out = api_.GetTensorMutableData<float>(output);
// 計算を実行
for (size_t i = 0; i < dimensions.Size(); ++i) {
out[i] = x[i] + y[i];
}
}
private:
const OrtApi& api_;
};

このコードは、2つの入力テンソルを取得し、それらを要素ごとに追加して、結果を出力テンソルに書き込むComputeメソッドを実装しています。

カスタム演算子カーネルの作成

Section titled “カスタム演算子カーネルの作成”

カスタム演算子カーネルは、カスタム演算子をONNX Runtimeに登録します。

struct CustomOpKernel {
CustomOpKernel(const OrtApi& api) : op_(api) {}
void Compute(OrtKernelContext* context) {
op_.Compute(context);
}
private:
CustomOp op_;
};
CustomOpKernel* CreateKernel(const OrtApi& api, const OrtKernelInfo* /*info*/) {
return new CustomOpKernel(api);
}
void ReleaseKernel(CustomOpKernel* kernel) {
delete kernel;
}

カスタム演算子スキーマの作成

Section titled “カスタム演算子スキーマの作成”

カスタム演算子スキーマは、カスタム演算子の入力と出力を定義します。

const char* GetName() { return "CustomOp"; }
const char* GetExecutionProviderType() { return "CPUExecutionProvider"; }
ONNXTensorElementDataType GetInputType(size_t /*index*/) { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }
size_t GetInputCount() { return 2; }
ONNXTensorElementDataType GetOutputType(size_t /*index*/) { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }
size_t GetOutputCount() { return 1; }

カスタム演算子は、RegisterCustomOpsLibrary関数を使用してONNX Runtimeに登録されます。

OrtStatus* RegisterCustomOps(OrtSessionOptions* options, const OrtApi* api) {
OrtCustomOpDomain* domain = nullptr;
api->CreateCustomOpDomain("custom.op", &domain);
OrtCustomOp op;
op.version = 1;
op.CreateKernel = (void*)CreateKernel;
op.GetName = (void*)GetName;
op.GetExecutionProviderType = (void*)GetExecutionProviderType;
op.GetInputType = (void*)GetInputType;
op.GetInputCount = (void*)GetInputCount;
op.GetOutputType = (void*)GetOutputType;
op.GetOutputCount = (void*)GetOutputCount;
op.KernelDestroy = (void*)ReleaseKernel;
api->CustomOpDomain_Add(domain, &op);
return api->AddCustomOpDomain(options, domain);
}

カスタム演算子は、ONNXモデルの他の演算子と同じように使用できます。

import onnxruntime
import numpy as np
# セッションオプションを作成
so = onnxruntime.SessionOptions()
# カスタム演算子ライブラリを登録
so.register_custom_ops_library("custom_op_library.so")
# ONNXモデルをロード
session = onnxruntime.InferenceSession("model.onnx", so)
# 入力を作成
x = np.array([[1, 2], [3, 4]]).astype(np.float32)
y = np.array([[5, 6], [7, 8]]).astype(np.float32)
# 推論を実行
result = session.run(None, {"X": x, "Y": y})
# 結果を出力
print(result[0])

このスクリプトは、カスタム演算子ライブラリを登録し、カスタム演算子を使用するONNXモデルをロードし、推論を実行します。