TorchScript, torch.fx, TorchDynamo, torch.compile, and torch.export
TorchScriptとtorch.fx、その周辺技術に関するメモ。
TorchScript
TorchScriptとは、
- Pythonコードをシリアライズ可能・最適化可能なモデルへ変換するための方法 (Optimizer)
- 静的型付けされたPythonサブセット言語 (Language)
- PyTorchモデルの中間表現 (IR; intermediate representation)
- ハイパフォーマンスなデプロイランタイム (Runtime)
である。
TorchScriptは独自のインタプリタでも実行できる(Pythonでも実行できる)ので、Pythonインタプリタに依存しない環境へのデプロイに使える。 例えばシリアライズして、C++ライブラリのlibtorchを使って推論もできる。
TorchScriptへの変換は、@torch.jit.script
デコレータでPythonコードをパースして構築 (Scripting) するか、torch.jit.trace
を使って実行記録から構築 (Tracing) する。
ScriptingはTorchScript言語でPythonコードを記述する必要があり、Pythonコードの修正が必要。
Tracingはcontrol-flow(if-else、for、whileなどのオペレーション)を記録しないので、これも記録する場合はScriptingを使う必要がある。
torch.nn.Module
を変換した場合はtorch.jit.ScriptModule
を得る。メソッド、attribute、パラメータ、定数を保持する。
単体の関数を変換した場合はtorch.jit.ScriptFunction
を得る。attributeやパラメータは保持しない。
Apache TVM/RelayのPyTorch Frontendは、TorchScriptのみ対応している。
torch.fx
torch.nn.Module
の書き換えを容易にするため仕組み。次の3️つの要素から構成される。
- Tracer (
torch.fx.symbolic_trace
)- PythonコードをFX IR (FX中間表現)へ変換する。
- Proxyと呼ぶ偽のインプットを入力に与え、symblic execution (シンボリック実行)により実行記録から構築する。
- symblic execution <-> eager execution。
- ただし、tracingなのでcontrol-flow(if-else、for、whileなどのオペレーション)を変換できない。これは
torch.jit.trace
の欠点と同様。
- FX IR (FX中間表現)
fx.GraphModule
:torch.nn.Module
のtorch.fx特化版。fx.Graph
: 計算グラフを表現するクラス。fx.Node
:fx.Graph
内のoperationや式を表現するクラス。operationの種類は次のとおり。placeholder
: 関数への入力を表す。get_attr
:torch.nn.Module
からパラメータを取得する。call_function
: `関数呼び出し。call_module
:torch.nn.Module
の呼び出し。call_method
: Pythonオブジェクトのメソッド呼び出し。x.view()
など。output
: 計算グラフの出力を表す。
- Code generation (コード生成)
fx.Graph
をPythonコードへ変換する。
これら3つを使うことで、[Python code] -> symbolic tracing -> [FX IR] -> transformation (IR編集) -> [optimized FX IR] -> Code generation -> [optimized Python code]のような書き換えパイプラインを構築できる。また、Pythonで実装されいているのでカスタマイズが容易、Pythonデバッガ等のPythonエコシステムを利用できる。
Apache TVM/RelaxのPyTorch Frontendはfx.GraphModuleに対応している。
TorchDynamo
TorchDynamoとは、PythonコードをFX IR (FX中間表現)へ変換するためのTracer。CPythonのFrame Evaluation API (PEP523)機能を使ってPyTorchのグラフを取得する。Pythonコードにcontrol-flowを含んでいてもtracingできるが、control-flow自体をFX IRへ変換できるわけではない。 (多分)
torch.compile
torch.compileとは、PythonコードをJITコンパイルし、高速化するための仕組み。 TorchDynamoでGraphをキャプチャし、バックエンドコンパイラ (e.g. TorchInductor) でコンパイルする。
Partial graph captureとFull graph captureの両方をサポートしている。 Partial graph captureの場合、traceできない部分を起点にGraphを分割し、その部分のみPythonで実行する。 Full graph captureだとtraceできない部分にぶつかるとエラーになる。
torch.export
Python callable (torch.nn.Module
、関数、メソッド) をExport IRに変換し、シリアライズするための仕組み。
Export IRはtorch.fx.Graph
上に表現される中間表現で、fx.Graph
としても解釈できる。
torch.export()
を実行するとtorch.export.ExportedProgram
クラスのインスタンスが生成される。ここにfx.Graph
やパラメータ、メタデータが格納される。
torch.export.save()
/torch.export.load()
でExportedProgram
をファイルに保存/読込できる。
torch.export.dynamic_shapes.Dim
でtensor shapeをdynamic shape化やshapeのrange指定などもできる。
data/shape-dependent control-flowは対応していないので、当該部分は書き換える必要がある。現状はtorch.cond
のみサポート。
References
- TorchScript — PyTorch 2.4 documentation
- TorchScript Language Reference — PyTorch 2.4 documentation
- Key points to grasp for TorchScript beginners | by Huawei Zhu | Medium
- TorchScript: Tracing vs. Scripting - Yuxin’s Blog
- torch.fx — PyTorch 2.4 documentation
- torch.fx.experimental.symbolic_shapes — PyTorch 2.4 documentation
- Dynamo Overview — PyTorch 2.4 documentation
- torch.compile — PyTorch 2.4 documentation
- torch.export — PyTorch 2.4 documentation
- torch.export IR Specification — PyTorch 2.4 documentation
- [Public] PT2 Backend Integration.ipynb - Colab