PyTorch RNN-T モデルの変換#

危険

ここで説明されているコードは非推奨になりました。従来のソリューションの適用を避けるため使用しないでください。下位互換性を確保するためにしばらく保持されますが、最新のアプリケーションでは使用してはなりません

このガイドでは、非推奨となった変換方法について説明します。新しい推奨方法に関するガイドは、Python チュートリアルに記載されています。

このガイドでは、MLCommons リポジトリーからの RNN-T モデルの変換について説明します。IR に変換する前に、以下の手順に従って PyTorch モデルを ONNX にエクスポートします:

ステップ 1.MLCommons リポジトリー (リビジョン r1.0) から RNN-T PyTorch 実装のクローンを作成します。完全なリポジトリーを使用せずに RNN-T モデルのみを取得する浅いクローンを作成します。すでに完全なリポジトリーがある場合、これをスキップして ステップ 2 に進みます:

git clone -b r1.0 -n https://github.com/mlcommons/inference rnnt_for_openvino --depth 1 
cd rnnt_for_openvino 
git checkout HEAD speech_recognition/rnnt

ステップ 2.MLCommons 推論リポジトリーの完全なクローンがすでに存在する場合、IR への変換が行われる事前トレーニング済みの PyTorch モデル用のフォルダーを作成します。ステップ 5 では、完全なクローンへのパスも指定する必要があります。浅いクローンがある場合は、この手順をスキップしてください。

mkdir rnnt_for_openvino 
cd rnnt_for_openvino

ステップ 3.こちらから、PyTorch 実装用の事前トレーニングされた重みをダウンロードします。UNIX のようなシステムでは wget を使用できます:

wget https://zenodo.org/record/3662521/files/DistributedDataParallel_1576581068.9962234-epoch-100.pt

リンクは speech_recoginitin/rnnt サブフォルダー内の setup.sh から取得されました。ガイドに従った場合と全く同じ重みが得られます。

ステップ 4.必要な Python パッケージをインストールします:

pip3 install torch toml

ステップ 5.以下のスクリプトを使用して、RNN-T モデルを ONNX にエクスポートします。以下のコードを export_rnnt_to_onnx.py ファイルにコピーし、現在のディレクトリー rnnt_for_openvino で実行します。

MLCommons 推論リポジトリーの完全なクローンがすでにある場合は、mlcommons_inference_path 変数を指定する必要があります。

import toml 
import torch 
import sys 

def load_and_migrate_checkpoint(ckpt_path): 
    checkpoint = torch.load(ckpt_path, map_location="cpu") }
    migrated_state_dict = {} 
    for key, value in checkpoint['state_dict'].items(): 
        key = key.replace("joint_net", "joint.net") 
        migrated_state_dict[key] = value 
    del migrated_state_dict["audio_preprocessor.featurizer.fb"] 
    del migrated_state_dict["audio_preprocessor.featurizer.window"] 
    return migrated_state_dict 

mlcommons_inference_path = './'# specify relative path for MLCommons inferene 
checkpoint_path = 'DistributedDataParallel_1576581068.9962234-epoch-100.pt' 
config_toml = 'speech_recognition/rnnt/pytorch/configs/rnnt.toml' 
config = toml.load(config_toml) 
rnnt_vocab = config['labels']['labels'] 
sys.path.insert(0, mlcommons_inference_path + 'speech_recognition/rnnt/pytorch') 

from model_separable_rnnt import RNNT 

model = RNNT(config['rnnt'], len(rnnt_vocab) + 1, feature_config=config['input_eval']) 
model.load_state_dict(load_and_migrate_checkpoint(checkpoint_path)) 

seq_length, batch_size, feature_length = 157, 1, 240 
inp = torch.randn([seq_length, batch_size, feature_length]) 
feature_length = torch.LongTensor([seq_length]) 
x_padded, x_lens = model.encoder(inp, feature_length) 
torch.onnx.export(model.encoder, (inp, feature_length), "rnnt_encoder.onnx", opset_version=12, 
                  input_names=['input', 'feature_length'], output_names=['x_padded', 'x_lens'], 
                  dynamic_axes={'input': {0: 'seq_len', 1: 'batch'}}) 

symbol = torch.LongTensor([[20]]) 
hidden = torch.randn([2, batch_size, 320]), torch.randn([2, batch_size, 320]) 
g, hidden = model.prediction.forward(symbol, hidden) 
torch.onnx.export(model.prediction, (symbol, hidden), "rnnt_prediction.onnx", opset_version=12, 
                  input_names=['symbol', 'hidden_in_1', 'hidden_in_2'], output_names=['g', 'hidden_out_1', 'hidden_out_2'], 
                  dynamic_axes={'symbol': {0: 'batch'}, 'hidden_in_1': {1: 'batch'}, 'hidden_in_2': {1: 'batch'}}) 

f = torch.randn([batch_size, 1, 1024]) 
model.joint.forward(f, g) 
torch.onnx.export(model.joint, (f, g), "rnnt_joint.onnx", opset_version=12, 
                   input_names=['0', '1'], output_names=['result'], dynamic_axes={'0': {0: 'batch'}, '1': {0: 'batch'}})
python3 export_rnnt_to_onnx.py

この手順が完了すると、ファイル rnnt_encoder.onnxrnnt_prediction.onnx、および rnnt_joint.onnx が現在のディレクトリーに保存されます。

ステップ 6.変換コマンドを実行します:

mo --input_model rnnt_encoder.onnx --input "input[157,1,240],feature_length->157" 
mo --input_model rnnt_prediction.onnx --input "symbol[1,1],hidden_in_1[2,1,320],hidden_in_2[2,1,320]" 
mo --input_model rnnt_joint.onnx --input "0[1,1,1024],1[1,1,320]"

シーケンス長 = 157 のハードコードされた値は MLCommons から取得されたものですが、IR への変換によりネットワークの再形成の可能性が維持されます。入力形状は、変換中または推論中に手動で任意の値に変更できます。