コンテンツにスキップ

大規模モデルのトレーニング

ORTModuleによる大規模モデルのトレーニング入門

Section titled “ORTModuleによる大規模モデルのトレーニング入門”

ONNX Runtime TrainingORTModuleは、PyTorchフロントエンドを使用して定義されたモデル向けの高性能なトレーニングエンジンを提供します。ORTModuleは、モデル定義を変更することなく、トレーニングスクリプト全体に1行のコード変更(ORTModuleラップ)だけで大規模モデルのトレーニングを高速化するように設計されています。

ORTModuleクラスラッパーを使用すると、ONNX Runtimeは最適化された自動エクスポートされたONNX計算グラフを使用して、トレーニングスクリプトのフォワードパスとバックワードパスを実行します。

この例では、PyTorchでモデルをトレーニングするためにORTを使用する方法について説明します。

Terminal window
# torch_ortおよびonnxruntime-training Pythonパッケージをインストールします
pip install torch-ort
# ユーザーのPyTorchインストールで動作するようにonnxruntime-trainingを設定します
python -m torch_ort.configure

: これにより、特定のバージョンのCUDAライブラリにマッピングされているtorch-ortおよびonnxruntime-trainingパッケージのデフォルトバージョンがインストールされます。onnxruntime.aiのインストールオプションを参照してください。

  • train.pyにORTModuleを追加します
from torch_ort import ORTModule
.
.
.
model = build_model() # ユーザーのPyTorchモデル
model = ORTModule(build_model())

ONNX Runtimeトレーニングの例