カスタム演算子
カスタム演算子チュートリアル
Section titled “カスタム演算子チュートリアル”このチュートリアルでは、ONNX Runtimeでカスタム演算子を作成して使用する方法を示します。
- ONNX RuntimeがC++でビルドされていること
- C++コンパイラがインストールされていること
カスタム演算子の作成
Section titled “カスタム演算子の作成”カスタム演算子はC++で実装されます。次のコードは、2つのテンソルを追加する単純なカスタム演算子の例です。
#include <onnxruntime_cxx_api.h>
struct CustomOp { CustomOp(const OrtApi& api) : api_(api) {}
void Compute(OrtKernelContext* context) { // 入力を取得 const OrtValue* input_x = api_.KernelContext_GetInput(context, 0); const OrtValue* input_y = api_.KernelContext_GetInput(context, 1); const float* x = api_.GetTensorData<float>(input_x); const float* y = api_.GetTensorData<float>(input_y);
// 出力形状を取得 OrtTensorDimensions dimensions(api_, input_x);
// 出力を作成 OrtValue* output = api_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size()); float* out = api_.GetTensorMutableData<float>(output);
// 計算を実行 for (size_t i = 0; i < dimensions.Size(); ++i) { out[i] = x[i] + y[i]; } }
private: const OrtApi& api_;};このコードは、2つの入力テンソルを取得し、それらを要素ごとに追加して、結果を出力テンソルに書き込むComputeメソッドを実装しています。
カスタム演算子カーネルの作成
Section titled “カスタム演算子カーネルの作成”カスタム演算子カーネルは、カスタム演算子をONNX Runtimeに登録します。
struct CustomOpKernel { CustomOpKernel(const OrtApi& api) : op_(api) {}
void Compute(OrtKernelContext* context) { op_.Compute(context); }
private: CustomOp op_;};
CustomOpKernel* CreateKernel(const OrtApi& api, const OrtKernelInfo* /*info*/) { return new CustomOpKernel(api);}
void ReleaseKernel(CustomOpKernel* kernel) { delete kernel;}カスタム演算子スキーマの作成
Section titled “カスタム演算子スキーマの作成”カスタム演算子スキーマは、カスタム演算子の入力と出力を定義します。
const char* GetName() { return "CustomOp"; }const char* GetExecutionProviderType() { return "CPUExecutionProvider"; }
ONNXTensorElementDataType GetInputType(size_t /*index*/) { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }size_t GetInputCount() { return 2; }
ONNXTensorElementDataType GetOutputType(size_t /*index*/) { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }size_t GetOutputCount() { return 1; }カスタム演算子の登録
Section titled “カスタム演算子の登録”カスタム演算子は、RegisterCustomOpsLibrary関数を使用してONNX Runtimeに登録されます。
OrtStatus* RegisterCustomOps(OrtSessionOptions* options, const OrtApi* api) { OrtCustomOpDomain* domain = nullptr; api->CreateCustomOpDomain("custom.op", &domain);
OrtCustomOp op; op.version = 1; op.CreateKernel = (void*)CreateKernel; op.GetName = (void*)GetName; op.GetExecutionProviderType = (void*)GetExecutionProviderType; op.GetInputType = (void*)GetInputType; op.GetInputCount = (void*)GetInputCount; op.GetOutputType = (void*)GetOutputType; op.GetOutputCount = (void*)GetOutputCount; op.KernelDestroy = (void*)ReleaseKernel;
api->CustomOpDomain_Add(domain, &op); return api->AddCustomOpDomain(options, domain);}カスタム演算子の使用
Section titled “カスタム演算子の使用”カスタム演算子は、ONNXモデルの他の演算子と同じように使用できます。
import onnxruntimeimport numpy as np
# セッションオプションを作成so = onnxruntime.SessionOptions()
# カスタム演算子ライブラリを登録so.register_custom_ops_library("custom_op_library.so")
# ONNXモデルをロードsession = onnxruntime.InferenceSession("model.onnx", so)
# 入力を作成x = np.array([[1, 2], [3, 4]]).astype(np.float32)y = np.array([[5, 6], [7, 8]]).astype(np.float32)
# 推論を実行result = session.run(None, {"X": x, "Y": y})
# 結果を出力print(result[0])このスクリプトは、カスタム演算子ライブラリを登録し、カスタム演算子を使用するONNXモデルをロードし、推論を実行します。