DirectML使ってみた

冬は寒いのでDNNの学習を回すのにぴったり!GPUの廃熱で暖房費節約だぜ!などと思ったけれど、メインマシンはメインOSはWindowsで運用していてビデオカードAMDなのでDNNフレームワークを動かすのはしんどい。 調べたらDirectMLってのでWindows+AMDでもいけそうじゃん!ということで試しに使ってみた。
結論としてはGPU使用率があまり高くならず暖かくならなかった

環境構築(共通)

環境はAnacodaを使わずにWindowsPythonをインストールしてvenvで構築する。Anacondaはライセンス変わっちゃったからね

Pythonのインストール

python.orgからインストーラをダウンロードする。自分はこちらのページから3.8.10の「Windows installer (64-bit)」を選んだ。

インストール時にはpipが一緒にインストールされるようにオプションがOnになっていることを確認した。(自分の環境ではデフォルトでOnになっていた)

あと、Pythonのパッケージを入れる際に git.exe と cl.exe も必要になるのでインストールしてパスに追加する。

gitは「Git for Windows」を入れたような気がするが、ずいぶん前の事なので詳細は不明。

cl.exeはMicrosoft C++ Build Toolsページから「Build Tools のダウンロード」ボタンを押してインストーラを取得、インストーラからMicrosoft Visual C++だか何だかを選択して入れたような気がする。(こちらもうろ覚え)

venv

Pythonをインストールした時点でvenvも入っているのでそのままvenv環境を作れる。
構築直後はpipのバージョンが古いのでバージョンアップしておく。この時、直接pipコマンドでバージョンアップしようとすると環境を壊してしまうので注意。python -m pipでバージョンアップする。

>python -m venv env_top_dir
>env_top_dir\Scripts\activate
>python -m pip install --upgrade pip

ONNXRuntime

環境構築

venv環境にonnxとonnxruntime-directmlパッケージをインストールすればOK。

>pip install onnx onnxruntime-directml

お試し

基本的にはDirectMLのExecution Provider(DmlExecutionProvider)を指定するだけだが、2点注意点がある。

  1. opsetバージョンはv17まで
  2. セッションのオプションでenable_mem_patternを無効化しておく必要がある

どちらもDirectML版ONNXRuntimeが対応してないっぽい。ちなみにenable_mem_patternの方は無効化しなくても以下の警告が表示されて自動的に無効化されるっぽい。

[W:onnxruntime:, inference_session.cc:491 onnxruntime::InferenceSession::RegisterExecutionProvider] Having memory pattern enabled is not supported while using the DML Execution Provider. So disabling it for this session since it uses the DML Execution Provider.

以下お試しコード。モデルはConv1個だけのなんちゃってモデル。

import onnx
import onnx.numpy_helper
import numpy as np
import onnxruntime as ort


# Conv1個だけのモデル
inputs  = [onnx.helper.make_tensor_value_info('input' , onnx.TensorProto.FLOAT, [1, 3, 4, 4])]
outputs = [onnx.helper.make_tensor_value_info('output', onnx.TensorProto.FLOAT, [1, 1, 4, 4])]
nodes   = [onnx.helper.make_node('Conv', ['input', 'weight'], ['output'])]
inits   = [onnx.numpy_helper.from_array(1.0 / 4.0 * np.ones([1, 3, 1, 1], dtype=np.float32), 'weight')]
model = onnx.helper.make_model(onnx.helper.make_graph(nodes, 'conv', inputs, outputs, inits), opset_imports=[onnx.helper.make_opsetid('', 17)])

# Onだと警告が出るのであらかじめOff設定を入れておく
options = ort.SessionOptions()
options.enable_mem_pattern = False

# ExecutionProviderにDML版を指定して実行する
sess = ort.InferenceSession(model.SerializeToString(), options, ['DmlExecutionProvider', 'CPUExecutionProvider'])
sess.run(None, {'input': np.ones([1, 3, 4, 4], dtype=np.float32)})

TensorFlow

環境構築

venv環境にpipで入れるだけ。ほかに必要なパッケージは依存パッケージとして自動で入った。

>pip install tensorflow-directml-plugin

お試し

これで普通に動いた。'/job:localhost/replica:0/task:0/device:GPU:0'などと表示されたのでたぶん動いてる。

import tensorflow as tf
a = tf.constant([1.5])
b = tf.constant([0.5])
(a + b).device

あと公式のサンプルをそのまま書かれている通りに実行してみたら普通に動いた。データセットのダウンロードも自動で実行してくれてとても楽だった。

PyTorch

環境構築

同じくvenv環境にpipで入れる。

>pip install torchvision==0.14.0
>pip install torch==1.13
>pip install torch-directml

お試し

torch.deviceをDirectMLのもので指定すれば良いらしい。Tensor.to()には文字列を指定できずtorch.deviceを渡す必要がある。

あと、torch.Tensorをrepr()などで表示しようとするとエラーになる。(CPUに転送すれば表示できる)

import torch
import torch_directml


dml = torch_directml.device()

a = torch.tensor([1.5]).to(dml)
b = torch.tensor([0.5]).to(dml)
c = a + b
c.to('cpu')

簡単なモデルを作って動かしてみたがConv2d、BatchNorm2d、ReLU、Linearあたりは普通に動きそうだった。

mmdetectionでDETRの学習

PyTorchで動く物体検出向けフレームワーク?のMMDetectionを使ってDETR実装で学習を回すところまで改造してみた。

結論を先に言っておくとGPU使用率は上がらず温まらなかったtouch.deviceを入れ替えるだけでは動かなかった。

まだまだCPU実行時と同じ動きをしてくれないオペレーションがあるので既存のフレームワークなんかをそのまま使うのは厳しい、ということが分かった。
今はまだ公式のサンプルを使うのがよさそうに思える。サンプルのyolov3を試そうとしたらデータセットのダウンロード方法がよくわからず面倒になってやめてしまった

環境構築

このあたりを参考にしつつ以下の手順でvenv環境にインストールした。

>pip install mmcv-full==1.7.0 -f https://download.openmmlab.com/mmcv/dist/cpu/torch1.13/index.html
>git clone https://github.com/open-mmlab/mmdetection.git
>cd mmdetection
>pip install -v -e .
>pip install opencv-python

※DirectMLで動かすためにgit cloneしたmmdetectionリポジトリソースコードを改造して無理やり動かしている

さらにDETRの定義ファイルと重みデータをダウンロードする。

>pip install -U openmim
>mim download mmdet --config yolov3_mobilenetv2_320_300e_coco --dest checkpoints

データセットのダウンロード方法を見ながらMS COCOデータセットを用意して、学習の実行方法を参考にした。

最終的にはデータセットの置き場所をEドライブのdatasetsに変更していたので以下の感じで実行した。

>set "MMDET_DATASETS=E:/datasets/coco/"
>python source_packages\mmdetection\tools\train.py checkpoints\detr_r50_8x2_150e_coco.py --cfg-options data.samples_per_gpu=4

samples_per_gpuは1枚のビデオカードで一度に読み出すデータ数らしくてビデオカードが1枚しか存在しない環境ならそのままバッチサイズになるらしい。(たぶん。↑だとバッチサイズ4ということ)

困ったこと

DirectMLで動かそうとして遭遇したことは以下の通り。

  • VRAMが足りなくなるとブルースクリーンでOSごと落ちる(正確にはPCが再起動する)
  • DirectMLが対応していないオペレーションがある
    • エラーになるケース(Pythonの例外が送出される)とエラーにならず実行結果がCPU実行時と異なるケースの2パターンある
    • どちらのケースも該当箇所の処理をCPUデバイスで実行するようにすればとりあえず動くようになる

DirectMLが対応していなかった箇所(DETRで通過する箇所のみ)

mmdet/core/bbox/match_costs/match_cost.py

torch.cdist()で例外になる。

RuntimeError: The size of tensor a (2) must match the size of tensor b (100) at non-singleton dimension 0

CPU実行時はバッチ次元が異なっても問題なく実行できるがDirectML実行時はエラーになる。

@@ -47,8 +47,8 @@ class BBoxL1Cost:
             gt_bboxes = bbox_xyxy_to_cxcywh(gt_bboxes)
         elif self.box_format == 'xyxy':
             bbox_pred = bbox_cxcywh_to_xyxy(bbox_pred)
-        bbox_cost = torch.cdist(bbox_pred, gt_bboxes, p=1)
-        return bbox_cost * self.weight
+        bbox_cost = torch.cdist(bbox_pred.to('cpu'), gt_bboxes.to('cpu'), p=1)
+        return bbox_cost.to(gt_bboxes.device) * self.weight
mmdet/core/bbox/samplers/pseudo_sampler.py

unique()でエラーになる。(※エラーの内容はメモり忘れてた…)

@@ -33,9 +33,9 @@ class PseudoSampler(BaseSampler):
             :obj:`SamplingResult`: sampler results
         """
         pos_inds = torch.nonzero(
-            assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique()
+            assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).cpu().unique().to(gt_bboxes.device)
         neg_inds = torch.nonzero(
-            assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique()
+            assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).cpu().unique().to(gt_bboxes.device)
         gt_flags = bboxes.new_zeros(bboxes.shape[0], dtype=torch.uint8)
         sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
                                          assign_result, gt_flags)
mmdet/models/dense_heads/detr_head.py

このファイルは2か所あって、1つ目はテンソルの一部をSlice指定で上書きするコードがDirectMLだとなぜか上書きされないという挙動になる。 2つ目はバッチ次元が0(教師データのBBox数が0個)の時に [0, 4] shape との演算にDirectMLが対応していなくて例外になる。

RuntimeError: self must have at least one element!

@@ -244,10 +244,11 @@ class DETRHead(AnchorFreeHead):
         # ignored positions, while zero values means valid positions.
         batch_size = x.size(0)
         input_img_h, input_img_w = img_metas[0]['batch_input_shape']
-        masks = x.new_ones((batch_size, input_img_h, input_img_w))
+        masks = x.new_ones((batch_size, input_img_h, input_img_w)).cpu()
         for img_id in range(batch_size):
             img_h, img_w, _ = img_metas[img_id]['img_shape']
             masks[img_id, :img_h, :img_w] = 0
+        masks = masks.to(x.device)

         x = self.input_proj(x)
@@ -537,8 +538,8 @@ class DETRHead(AnchorFreeHead):
         # the box format should be converted from defaultly x1y1x2y2 to cxcywh.
         factor = bbox_pred.new_tensor([img_w, img_h, img_w,
                                        img_h]).unsqueeze(0)
-        pos_gt_bboxes_normalized = sampling_result.pos_gt_bboxes / factor
-        pos_gt_bboxes_targets = bbox_xyxy_to_cxcywh(pos_gt_bboxes_normalized)
+        pos_gt_bboxes_normalized = sampling_result.pos_gt_bboxes / factor if len(sampling_result.pos_gt_bboxes) else sampling_result.pos_gt_bboxes
+        pos_gt_bboxes_targets = bbox_xyxy_to_cxcywh(pos_gt_bboxes_normalized) if len(sampling_result.pos_gt_bboxes) else sampling_result.pos_gt_bboxes
         bbox_targets[pos_inds] = pos_gt_bboxes_targets
         return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
                 neg_inds)

torch.device関連コード(参考)

説明が面倒になってきたのでそのままソースコードの差分だけ貼っておきます。

ちゃんと対応するにはmmcv側から改造が必要になるのと、mmdetection内ではtorch.deviceを使わずデバイス名のstrを受け取る形で実装されているので、以下の箇所以外に色々修正しないとダメだったりで不完全なので。

mmdet/apis/inference.py
@@ -151,6 +151,7 @@ def inference_detector(model, imgs):
             assert not isinstance(
                 m, RoIPool
             ), 'CPU inference with RoIPool is not supported currently.'
+        data['img'] = [cpu_tensor.to(device) for cpu_tensor in data['img']]

     # forward the model
     with torch.no_grad():
mmdet/apis/train.py
@@ -41,6 +41,10 @@ def init_random_seed(seed=None, device='cuda'):
     if world_size == 1:
         return seed

+    if device == 'dml':
+        import torch_directml
+        device = torch_directml.device()
+
     if rank == 0:
         random_num = torch.tensor(seed, dtype=torch.int32, device=device)
     else:
mmdet/utils/util_distribution.py
@@ -33,6 +33,12 @@ def build_dp(model, device='cuda', dim=0, *args, **kwargs):
         from mmcv.device.mlu import MLUDataParallel
         dp_factory['mlu'] = MLUDataParallel
         model = model.mlu()
+    elif device == 'dml':
+        import torch_directml
+        from mmdet.device.dml import DMLDataParallel
+        dp_factory['dml'] = DMLDataParallel
+        dml = torch_directml.device()
+        model = model.to(dml)

     return dp_factory[device](model, dim=dim, *args, **kwargs)
@@ -55,7 +61,7 @@ def build_ddp(model, device='cuda', *args, **kwargs):
                      DistributedDataParallel.html
     """
     assert device in ['cuda', 'mlu',
-                      'npu'], 'Only available for cuda or mlu or npu devices.'
+                      'npu', 'dml'], 'Only available for cuda or mlu or npu devices.'
     if device == 'npu':
         from mmcv.device.npu import NPUDistributedDataParallel
         torch.npu.set_compile_mode(jit_compile=False)

@@ -81,9 +93,18 @@ def is_mlu_available():
     return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()


+def is_dml_available():
+    try:
+        import torch_directml
+        return torch_directml.is_available()
+    except ImportError as e:
+        return False
+
+
 def get_device():
     """Returns an available device, cpu, cuda or mlu."""
     is_device_available = {
+        'dml': is_dml_available(),
         'npu': is_npu_available(),
         'cuda': torch.cuda.is_available(),
         'mlu': is_mlu_available()
mmdet/device/dml 配下

こちらは本来はmmcvに入れるべきコード。面倒なのでmmdetection配下に入れた。

# __init__.py
from ._functions import scatter, scatter_kwargs
from .data_parallel import DMLDataParallel
from .distributed import DMLDistributedDataParallel


__all__ = ['scatter', 'scatter_kwargs', 'DMLDataParallel', 'DMLDistributedDataParallel']


# _functions.py
import torch
import torch_directml
from typing import Union, List
from mmcv.parallel.data_container import DataContainer
from mmcv.device._functions import Scatter


def _scatter_core(current_device: torch.device, obj: Union[List, torch.Tensor]):
    if isinstance(obj, list):
        return [_scatter_core(current_device, elem) for elem in obj]
    elif isinstance(obj, torch.Tensor):
        return obj.to(current_device)
    else:
        raise RuntimeError(f'obj is unsupported type {type(obj)}')


def _scatter_data_container(current_device: torch.device, obj: DataContainer):
    outputs = _scatter_core(current_device, obj.data)
    return tuple(outputs) if isinstance(outputs, list) else (outputs,)


def scatter(inputs, target_devices, dim=0):
    device_id = next(iter(target_devices), torch_directml.default_device())
    current_device = torch_directml.device(device_id)

    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            if target_devices != [-1]:
                obj = obj.to(current_device)
                return [obj]
            else:
                # for CPU inference we use self-implemented scatter
                return Scatter.forward(target_devices, obj)
        if isinstance(obj, DataContainer):
            if obj.cpu_only:
                return obj.data
            else:
                return _scatter_data_container(current_device, obj)
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            out = list(map(list, zip(*map(scatter_map, obj))))
            return out
        if isinstance(obj, dict) and len(obj) > 0:
            out = list(map(type(obj), zip(*map(scatter_map, obj.items()))))
            return out
        return [obj for _ in target_devices]

    try:
        return scatter_map(inputs)
    finally:
        scatter_map = None


def scatter_kwargs(inputs, kwargs, target_devices, dim=0):
    inputs = scatter(inputs, target_devices, dim) if inputs else []
    kwargs = scatter(kwargs, target_devices, dim) if kwargs else []

    if len(inputs) < len(kwargs):
        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
    elif len(kwargs) < len(inputs):
        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])

    inputs = tuple(inputs)
    kwargs = tuple(kwargs)

    return inputs, kwargs


# data_parallel.py
import torch_directml
from mmcv.parallel import MMDataParallel
from ._functions import scatter_kwargs


class DMLDataParallel(MMDataParallel):
    def __init__(self, *args, dim=0, **kwargs):
        super().__init__(*args, dim=dim, **kwargs)

        self.device_ids = kwargs.get('device_ids', [torch_directml.default_device()])
        self.src_device_obj = torch_directml.device(self.device_ids[0])

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

venvのトップディレクトリにPyTorchの重みファイル(pth)を置いちゃダメって話

横着したらハマったので自戒を込めてメモ。

まとめ

  • Pythonには「サイト固有の設定フック」という機能がある
    • venv環境のトップディレクトリに入っている .pth ファイルを読み込んで処理する
  • PyTorchの学習済み重みファイル(拡張子 .pth)を置いていると上記機能と衝突して誤動作する
    • エラーでpython.exeが実行できなくなる。pipコマンドもエラーで実行できなくなる
  • venv環境と作業ディレクトリはツリーを分けよう(戒め)

サイト固有の設定フック機能

  • venv環境の直下とlib\site-packagesディレクトリに入っている .pth ファイルを読み込むらしい
    • 後述するsite.pyにprint文を埋め込んで上記2ディレクトリが対象になっていることを確認した
  • .pthファイルにはパス名が書かれている前提っぽい

エラーの内容

venvのトップディレクトリ(以下の例だとE:\soft\env_torch直下)に重みファイル(.pth)を置いてからpython.exeを実行しようとすると以下のようなエラーになる。

(env_torch) E:\soft\env_torch>python
Fatal Python error: init_import_size: Failed to import the site module
Python runtime state: initialized
Traceback (most recent call last):
  File "E:\soft\Python38\lib\site.py", line 580, in <module>
    main()
  File "E:\soft\Python38\lib\site.py", line 563, in main
    known_paths = venv(known_paths)
  File "E:\soft\Python38\lib\site.py", line 495, in venv
    addsitepackages(known_paths, [sys.prefix])
  File "E:\soft\Python38\lib\site.py", line 350, in addsitepackages
    addsitedir(sitedir, known_paths)
  File "E:\soft\Python38\lib\site.py", line 208, in addsitedir
    addpackage(sitedir, name, known_paths)
  File "E:\soft\Python38\lib\site.py", line 164, in addpackage
    for n, line in enumerate(f):
UnicodeDecodeError: 'cp932' codec can't decode byte 0x8a in position 2: illegal multibyte sequence

PyTorchの重みファイルはバイナリファイルだが、これをパス名が記載されているテキストファイルとして読み込もうとしてエラーになっているっぽい。

エラーメッセージは文字コード関連のメッセージだが騙されてはいけない。環境変数PYTHONUTF8を設定しても無駄である。

pythonもpipも実行できなくなるので問題のデバッグ自体がつらくなるので注意。今回はsite.pyを直接編集してprint文を埋め込んでデバッグした。

結論

  • venv環境の直下にはPyTorchの重みファイルを置いてはいけない
  • そもそも作業ディレクトリを別ツリーに分けておけば回避できる(横着はよくなかった)

一応.ptファイルなら誤動作しないと思われるが、そこは問題の本質ではないと思う。論文の実装コードで自動ダウンロードが走ったりすることがあるし。

CUDA実装のONNXRuntimeカスタムオペレータを実装してみた

前回の続き。

前回の記事 ↓ maminus.hatenadiary.org

今回のソースコードgithub.com

ありがたいことにGitHubのIssueでCUDA版のカスタムオペレータ実装方法について問い合わせをいただいたので時間ができた時に実装してみた。

※ONNXRuntimeのカスタムオペレータについては前回の記事を参照

ONNXRuntimeのExecution Providerについて

CUDA版のカスタムオペレータを実装する場合、一応以下の2パターン方法が考えられる。

  1. カスタムオペレータだけCUDAで実行し、他のオペレータはCPUで実行する
  2. すべてCUDAで実行する

1のパターンはあまりうれしくないと思うので2のカスタムオペレータも他のオペレータもCUDAで動かすことを考える。

この時デフォルトだとONNXRuntimeはCPUで動くのでCUDA版のExecution Providerというのを指定して動かす必要がある。

Execution ProviderというのはONNXRuntime内のフレームワークらしくて、推論処理などのハードウェアアクセラレーションが受けられそうな箇所を切り替え可能な機能ブロックとして分離しているらしい。

詳細は以下の公式ドキュメントを参照のこと。 onnxruntime.ai

そして、CUDAExecutionProviderを使うと推論処理をCUDA版で実行できるらしい。

CUDA版のカスタムオペレータを実装する時にはCUDAExecutionProvider版のカスタムオペレータとして実装すれば全体がCUDAで実行できることになる。

CUDAExecutionProviderでカスタムオペレータを実装する際のポイント

少しだけ前回の復習をしておくとC++でカスタムオペレータを実装する時には主に以下の要素を実装した。

実装すべきもの 主な用途
kernelクラス 計算処理本体
オペレータクラス カスタムオペレータの仕様(引数の個数など)に関する情報を返すクラス
Register関数など ONNXRuntimeへの登録

CUDA版を実装する際には大雑把には以下のような実装内容の違いがある。

実装すべきもの CPU版 CUDA版
kernelクラス 計算処理本体を実装する CUDAホスト関数を呼び出す
オペレータクラス CPUExecutionProviderとして実装する CUDAExecutionProviderとして実装する
Register関数など CPU/CUDAどちらも同じ CPU/CUDAどちらも同じ
CUDAソースコード 実装不要 計算処理本体を実装する

具体的な変更点は以下の通り。

  • kernelクラスのCompute()メソッド
    • GetTensorData()はCUDAのデバイスポインタが返ってくるのでそのままCUDA関数に渡してOK
    • CUDAのStream IDはKernelContext_GetGPUComputeStream()で取得できる(該当コード
  • オペレータクラスのGetExecutionProviderType()メソッド
    • CPU版ではオーバーライド不要
    • 文字列"CUDAExecutionProvider"を返すようにオーバーライドすることでCUDAExecutionProviderになる(該当コード

Pythonの呼び出しコード

Pythonの推論処理コードはInferenceSessionのコンストラクタ引数にExecutionProviderのリストを追加で指定するようにすればOK。

import onnxruntime as ort


model = ...

# カスタムオペレータを実装したDLLをロードする
option = ort.SessionOptions()
option.register_custom_ops_library('./libmy_custom_multi_with_cuda.so')

# 使いたいExecutioinProviderを指定する
providers = ['CUDAExecutionProvider']

# カスタムオペレータDLLとExecutionProviderを指定してセッションを生成する
sess = ort.InferenceSession(model.SerializeToString(), option, providers)

# sess.run()で推論できる
...

お試しコードか上で紹介した公式ドキュメントも参考に。

ONNXRuntimeのカスタムオペレータを実装してみた

ソースコードはここ↓ github.com

背景とか

TensorRTを使おうとしてモデルの一部をONNXのカスタムオペレータにすることがある。(TensorRTのプラグインを使うケース)

ただカスタムオペレータを含むモデルはそのままではONNXRuntimeで推論できない。ということはONNX Simplifierのようなツールを使うこともできない。 カスタムオペレータを実装することで推論が可能になってツール類も使える。

# ONNXのカスタムオペレータはノードの'domain' attributeに独自ドメイン名を指定すれば作れる
nodes = [onnx.helper.make_node('Fma', ['A', 'B', 'C'], ['out'], domain='ai.onnx.contrib')]

カスタムオペレータの実装手段

少し調べたところ、大きく分けて2つのやり方がある。

  1. Pythonでカスタムオペレータを実装する
  2. C++でカスタムオペレータを実装する

Pythonで実装すると推論コードもPythonで記載できてPythonのみで実現できるので楽ちん。ただし、onnxruntime-extensionsパッケージが必要になるのと、強めの制約がある。(後述)

C++は実装が面倒だがPythonよりは制約を緩められる。

Pythonで実装する方法

onnx_opデコレータをカスタムオペレータ実装ルーチンにつけるだけ。引数のテンソルはnumpy.ndarrayが渡される。戻り値のテンソルもndarrayを返せばOK。

# 引数はすべてfloat32型、戻り値もfloat32型。op_typeは'Fma'
@onnx_op(op_type='Fma', inputs=[PyOp.dt_float, PyOp.dt_float, PyOp.dt_float], outputs=[PyOp.dt_float])
def fma(a, b, c):
    return a * b + c

ただし、float32バージョンfloat64バージョン、のように扱うデータ型のバリデーションを作ることができない。 これはop_typeに対するルーチンが1つしか登録できない作りになっているためと思われる。

推論は以下のようにする。以下のコードでmodel_func()の呼び出しがONNXモデル全体の推論実行処理になっている。

model_func = PyOrtFunction.from_model(_ONNX_FILE_NAME)
result = model_func(A, B, C)

C++で実装する方法

C++で実装する場合は主に以下の要素を用意すればよい。

  • void Compute(OrtKernelContext* context)メソッドを持つkernelクラス
    • 計算処理本体を実装する
  • Ort::CustomOpBase<Op, Kernel>クラスを継承し必要なメソッドを実装したクラス
    • void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const
    • const char* GetName() const
      • op_type名を返す
    • ONNXTensorElementDataType GetInputType(size_t index) const
      • index番目の入力データ型を返す
    • size_t GetInputTypeCount() const
      • オペレータ引数の数を返す
    • OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t index) const
      • 省略可能な引数を持つ場合に実装する
      • index番目の引数が必須か省略可能かを返す
    • ONNXTensorElementDataType GetOutputType(size_t index) const
      • index番目の戻り値のデータ型を返す
    • size_t GetOutputTypeCount() const
      • オペレータの戻り値の数を返す
    • OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t index) const
      • 省略可能な戻り値を持つ場合に実装する
      • index番目の戻り値が必須か省略可能かを返す
  • RegisterCustomOps()関数

ソースコードのビルドはONNXRuntimeの3つのヘッダファイルがあればOK。 CMakeLists.txtではfind_path()でヘッダファイルの場所を探すようにしているので参考に。

kernelクラス

kernelクラスは大雑把に以下の構造になるように実装する。

struct FmaKernel {
    FmaKernel(OrtApi api):api_(api), ort_(api_) {}

    void Compute(OrtKernelContext* context) {
        // ... カスタムオペレータの計算処理
    }
private:
    OrtApi api_;
    Ort::CustomOpApi ort_;
};

入力データへのポインタなどの計算に必要なデータはOrt::CustomOpApiでアクセスできる。ただし、CustomOpApiクラスはコンストラクタ引数のOrtApiインスタンスを参照で保持するためOrtApiインスタンスのコピーをkernelクラスのメンバに保持する必要があるとのこと。

// 0番目の引数(float型)へのポインタをもらう例
const auto input_a = ort_.KernelContext_GetInput(context, 0);
auto ptr_a = ort_.GetTensorData<float>(input_a);

// 出力0を[1, 3, 224, 224]のshapeで作ってポインタをもらう例
size_t shape_dim = 4;
const int64_t shape[shape_dim] = {1, 3, 224, 224};
auto output_0 = ort_.KernelContext_GetOutput(context, 0, shape, shape_dim);
auto ptr_0 = ort_.GetTensorMutableData<float>(output_0);

もし出力shapeが入力と同じならGetTensorTypeAndShape()GetTensorShape()を呼ぶと入力shapeをもらえるので出力shapeの指定にそのまま使えばよい。

オペレータクラス

オペレータクラスは以下のようにCustomOpBaseクラスを継承する。CustomOpBaseクラスのテンプレート引数には自分自身とkernelクラスを指定する。

struct CustomOpFma : Ort::CustomOpBase<CustomOpFma, FmaKernel> {
    // 対応するkernelクラスのインスタンスをnewして返す
    void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
        return new FmaKernel(api);
    }
    // 単純に対応するop_type名を返すだけでOK
    const char* GetName() const {
        return "Fma";
    }
    // ...
};

後のI/Fは特筆すべき内容は無いのでinput側だけ掲載する。具体的な実装コードはGitHubにpushした実装コードを参照のこと。

 // index番目の引数のデータ型をONNXTensorElementDataType(onnxruntime_c_api.hで定義されている)で返す
    ONNXTensorElementDataType GetInputType(size_t index) const {
        if (index > 0) {
            // "T"を指定可能なのは1つの引数のみ。残りはFLOATなどの具体的な型を返す必要がある
            return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
        }
        // UNDEFINEDを返すとデータ型は"T"(任意型)として扱われる
        return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
    }
    // 引数の数を返す。Fmaの例なら引数は3つなので3を返せばOK
    size_t GetInputTypeCount() const {
        return 3;
    }
    // index番目の引数が省略可能ならOPTIONAL、必須ならREQUIREDを返す
    OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t index) const {
        if (index > 1) {
            // 3つ目(index == 2)の引数を省略可能とする例
            return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL;
        }
        return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
    }

注意点

float32版オペレータfloat64版オペレータなどと複数データ型に対応させたい場合に問題点がある。2つのやり方がある。

  1. データ型をONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINEDで報告する
  2. データ型ごとに別々のドメイン名に分ける

1の方法はデータ型を"T"(任意型)に指定する方法だが、任意型をとる引数が1つのみのケースしか使えない。FMAのように3つの引数を"T"と指定することができない。

2の方法はONNXモデル側でドメインを分ける必要が発生してしまうかわりに各データ型に対応する実装を作ることができる。もしデータ型ごとにop_type名を変えることができるならドメイン名は同じでop_type名で分ける方法もある。 (カスタムオペレータはop_type名につき1種類の実装しか登録できないためop_type名かドメイン名を分ける必要がある)

RegisterCustomOps関数

このルーチンはDLLのロード時(正確にはregister_custom_ops_library()呼び出し時)に呼ばれる。

constexpr std::string_view domain_name = "my_ops";
CustomOpFma op_fma;

extern "C" {
OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api_base) {
    const OrtApi* ort_api = api_base->GetApi(ORT_API_VERSION);
    OrtStatus* status;

    // ドメインオブジェクトを作る
    OrtCustomOpDomain* domain = nullptr;
    if (status = ort_api->CreateCustomOpDomain(domain_name.data(), &domain)) {
        return status;
    }

    // ドメインオブジェクトを(後で解放処理を実行するために)覚えておく
    register_domain(domain, ort_api);

    // ドメインオブジェクトにカスタムオペレータを登録する
    if (status = ort_api->CustomOpDomain_Add(domain, &op_fma)) {
        return status;
    }

    // ドメインオブジェクトをONNXRuntimeに登録する
    status = ort_api->AddCustomOpDomain(options, domain);
    return status;
}
}   // extern "C"

register_domain()は何かのI/Fとかではなく、今回独自に実装した。やっていることはドメインオブジェクトのポインタをdeleter付きのunique_ptrでくるんでvectorに登録しているだけ。 ドメインオブジェクトだけ?はインスタンスの削除をこちらで実施する必要があるっぽい。そしてdeleteで削除するのではなくRelease系メソッド呼び出しが必要らしい。

具体的な処理内容は実装コードを参照のこと。

推論処理

Pythonで推論するには以下のようにInferenceSessionの引数にSessionOptionsを渡せばよい。(ONNXモデルのカスタムオペレータのドメイン名、op_type名とC++実装コードのドメイン名、CustomOpFma::GetName()が返す名前が一致している必要がある)

import onnx
import numpy as np
import onnxruntime as ort

option = ort.SessionOptions()
option.register_custom_ops_library('./libmy_custom_t.so')

model = onnx.load(_ONNX_FILE_NAME)
sess = ort.InferenceSession(model.SerializeToString(), option)

A = np.ones([1], dtype=np.float32)
B = np.ones([1], dtype=np.float32)
results = sess.run(None, {'A': A, 'B': B})

TensorRTビルド(とEfficientDet実行環境)のDockerfile作ってみた

AutoML版EfficientDetをTensorRT化しようとして色々あってTensorRTのビルド環境と一緒に環境作りたくなったのでDockerfileとscriptを作ってgitにpushしてみた。

github.com

はまったことや注意点を列挙したい。

pycudaより前にnumpyのインストールが必要

setup.pyでひたすらエラーになって何度もバージョンをかけてsetup.pyを実行しようとしてしまうので先にnumpyをインストールする

ARM版bazelのインストール方法がややこしかった

結局実行ファイルを直接ダウンロードして配置する方法しかダメだった

tensorflow-addonsのバージョン依存パナイ

tensorflow-hubとtensorflow-addonsはARM版パッケージが登録されてなさそうなのでソースコードからビルドした。

addonsはバージョン依存がきつくて一覧表とにらめっこしてふさわしいバージョンを探す必要がある。

AutoML版EfficientDetのバージョン依存

masterブランチだと指定されているバージョンのライブラリを用意できないので1.2のソースコードを使うしかなかった。ただし、kerasディレクトリが本家Kerasパッケージと名称衝突してる件とかTensorFlow v2.6のtf.data.experimental.OptimizationOptionsからmap_vectorizationが消えていたりして1.2のソースコードそのままだと実行ができなかった。
(さすがにこの部分はスクリプトにも含めなかった)

TensorRTでfloat型tensorをResizeやSliceの第二以降の引数に指定するモデルが内部エラーになるのでフォーラムに連絡してみた

とあるディープラーニングモデルをONNXに変換してTensorRTで動かそうとしたらInternal Errorが出るのでNVIDIAさんのフォーラムに凸ってみた。

該当スレッドはここ

モデルに関する簡単な説明

まずはONNXのResizeオペレータとSliceオペレータについて。

Resizeoutput_feature = Resize(input_feature, roi, scales, sizes)のような形式で第四引数sizesで拡大・縮小後の画像サイズを指定すると拡大・縮小してくれるオペレータ。 Sliceoutput_feature = Slice(input_feature, starts, ends, axes, steps)のような形式で第二引数以降で指定したスライス操作を行ってくれるオペレータ。

これらのオペレータに対して第二引数以降の引数にfloat型tensorでサイズ計算などを実行してから整数型にキャストしたものを入れていた。例えばResizeには直前のConvolution結果のshapeを2倍に拡大するためにfloat型でwidth, heightを×2してからFloorで端数切り捨て+キャストで整数型にしてsizesに渡していた。

フォーラムに投稿

TensorRTはプラグインとモデルのパーサーについてはオープンソース公開されているが、他の部分はプロプリエタリでソースコードは非公開になっている。内部エラーということで、NVIDIAの中の人しか対応できない、ということでフォーラムに連絡することにした。

Issue報告時のガイドラインによると最小セットの再現モデルを提示した方が早そうだったので、不要なオペレータを削除しまくって再現するモデルを作った。

結論

いくつかディスカッションがあったので、結論をまとめると以下の通り。

  • Resizeのサイズ指定やSliceのスライス指定にfloat型のtensorを使えない
    • 途中でCastで整数型にしていてもダメ
  • TensorRT 8.4 でfloat型tensorに対応される
    • JetPack 5.0 DP(TensorRT 8.4.0 EA)で対応されていることを確認済み(2022/05/01 追記)
  • ただし、TensorRT 8.4 でもtensorをたどった先がモデル入力だった場合は非対応
  • プラグイン出力はResizeのサイズ指定やSliceのスライス指定に使えない(非対応)

補足

ディスカッションおよび途中で教えてもらったドキュメント8.5節の内容をざっくり整理すると以下の通り。

  • ReiszeやSliceの第二引数以降の引数はshape tensorとして扱われる
  • shape tensorはint32またはboolのみ対応
  • shape tensorは0次元または1次元のみ対応
  • shape tensorの計算はnetworkのビルド(ディープラーニングモデルを内部のエンジン形式に変換)する際のフェーズ1で計算される
  • shape計算はモデルのノードを入力側にたどりながら計算する
    • たどる時にshape tensorとして解釈できないノードがあると内部エラーになっているっぽい(想像)
  • shape計算はTensorRTがCPUで実行する
  • 実際にモデルの推論処理を動かすわけではないため、shape計算には制約がある
    • 例えばTensorRT 8.4でfloat型に対応してもモデル入力がshape tensorとして解釈される時にデータ型floatのケースは対応していない
    • 整数型にキャストしていてもfloat型未対応の制約を回避できない
    • プラグイン出力はshape tensorとして使えない

PyTorchモデルをONNX化する際にカスタムオペレータの実装が必要っぽい件

やりたかったこと

  • PyTorchモデルをONNXファイルにexportしたい
  • モデルにはカスタムオペレータを含む
  • カスタムオペレータの実装は存在しない(該当部分は推論不能
  • カスタムオペレータはONNXでそれらしきop_typeを持つノードで出力してくれればOK

やってみたこと

  • register_custom_op_symbolic()で'custom_ops::my_operator'名でONNXオペレータ定義のルーチンを登録
  • forward()メソッドでtorch.ops.custom_ops.my_operator(...)呼び出し
  • torch.onnx.export()でモデルを変換

結果

カスタムオペレータの実装を探しに行って例外発生。

  File "/var/lib/jenkins/.local/lib/python3.6/site-packages/torch/_ops.py", line 61, in __getattr__
    op = torch._C._jit_get_operation(qualified_op_name)
RuntimeError: No such operator custom_ops::batched_nms

考察(推測)

  • ONNX変換時にカスタムオペレータの実装コードが必要(っぽい)
  • register_custom_op_symbolic()は(おそらく)ONNX変換の時に使うだけで変換時のダミー推論には使われない
  • (おそらく)トレースありの状態でダミー推論を実行して、トレース中に呼び出されたオペレータをONNXに変換する、という形でONNX変換が実現されているような気がする
    • なので推論自体が実行できないとダメ、ということだと思われる
    • おそらくカスタムオペレータの実装はそれらしきダミーデータを固定値で返す実装でもONNX変換には影響なさそう

メモ

カスタムオペレータを実装する場合はこちらの公式チュートリアルを参考にすればよさそう。