TorchScript, torch.fx, TorchDynamo, torch.compile, and torch.export


TorchScriptとtorch.fx、その周辺技術に関するメモ。

TorchScript

TorchScriptとは、

  1. Pythonコードをシリアライズ可能・最適化可能なモデルへ変換するための方法 (Optimizer)
  2. 静的型付けされたPythonサブセット言語 (Language)
  3. PyTorchモデルの中間表現 (IR; intermediate representation)
  4. ハイパフォーマンスなデプロイランタイム (Runtime)

である。

TorchScriptは独自のインタプリタでも実行できる(Pythonでも実行できる)ので、Pythonインタプリタに依存しない環境へのデプロイに使える。 例えばシリアライズして、C++ライブラリのlibtorchを使って推論もできる。

TorchScriptへの変換は、@torch.jit.scriptデコレータでPythonコードをパースして構築 (Scripting) するか、torch.jit.traceを使って実行記録から構築 (Tracing) する。 ScriptingはTorchScript言語でPythonコードを記述する必要があり、Pythonコードの修正が必要。 Tracingはcontrol-flow(if-else、for、whileなどのオペレーション)を記録しないので、これも記録する場合はScriptingを使う必要がある。

torch.nn.Moduleを変換した場合はtorch.jit.ScriptModuleを得る。メソッド、attribute、パラメータ、定数を保持する。 単体の関数を変換した場合はtorch.jit.ScriptFunctionを得る。attributeやパラメータは保持しない。

Apache TVM/RelayのPyTorch Frontendは、TorchScriptのみ対応している。

torch.fx

torch.nn.Moduleの書き換えを容易にするため仕組み。次の3️つの要素から構成される。

  • Tracer (torch.fx.symbolic_trace)
    • PythonコードをFX IR (FX中間表現)へ変換する。
    • Proxyと呼ぶ偽のインプットを入力に与え、symblic execution (シンボリック実行)により実行記録から構築する。
      • symblic execution <-> eager execution。
    • ただし、tracingなのでcontrol-flow(if-else、for、whileなどのオペレーション)を変換できない。これはtorch.jit.traceの欠点と同様。
  • FX IR (FX中間表現)
    • fx.GraphModule: torch.nn.Moduleのtorch.fx特化版。
    • fx.Graph: 計算グラフを表現するクラス。
    • fx.Node: fx.Graph内のoperationや式を表現するクラス。operationの種類は次のとおり。
      • placeholder: 関数への入力を表す。
      • get_attr: torch.nn.Moduleからパラメータを取得する。
      • call_function: `関数呼び出し。
      • call_module: torch.nn.Moduleの呼び出し。
      • call_method: Pythonオブジェクトのメソッド呼び出し。x.view()など。
      • output: 計算グラフの出力を表す。
  • Code generation (コード生成)
    • fx.GraphをPythonコードへ変換する。

これら3つを使うことで、[Python code] -> symbolic tracing -> [FX IR] -> transformation (IR編集) -> [optimized FX IR] -> Code generation -> [optimized Python code]のような書き換えパイプラインを構築できる。また、Pythonで実装されいているのでカスタマイズが容易、Pythonデバッガ等のPythonエコシステムを利用できる。

Apache TVM/RelaxのPyTorch Frontendはfx.GraphModuleに対応している。

TorchDynamo

TorchDynamoとは、PythonコードをFX IR (FX中間表現)へ変換するためのTracer。CPythonのFrame Evaluation API (PEP523)機能を使ってPyTorchのグラフを取得する。Pythonコードにcontrol-flowを含んでいてもtracingできるが、control-flow自体をFX IRへ変換できるわけではない。 (多分)

torch.compile

torch.compileとは、PythonコードをJITコンパイルし、高速化するための仕組み。 TorchDynamoでGraphをキャプチャし、バックエンドコンパイラ (e.g. TorchInductor) でコンパイルする。

Partial graph captureとFull graph captureの両方をサポートしている。 Partial graph captureの場合、traceできない部分を起点にGraphを分割し、その部分のみPythonで実行する。 Full graph captureだとtraceできない部分にぶつかるとエラーになる。

torch.export

Python callable (torch.nn.Module、関数、メソッド) をExport IRに変換し、シリアライズするための仕組み。 Export IRはtorch.fx.Graph上に表現される中間表現で、fx.Graphとしても解釈できる。

torch.export()を実行するとtorch.export.ExportedProgramクラスのインスタンスが生成される。ここにfx.Graphやパラメータ、メタデータが格納される。 torch.export.save()/torch.export.load()ExportedProgramをファイルに保存/読込できる。 torch.export.dynamic_shapes.Dimでtensor shapeをdynamic shape化やshapeのrange指定などもできる。

data/shape-dependent control-flowは対応していないので、当該部分は書き換える必要がある。現状はtorch.condのみサポート。

References