畳み込みモデルのフィルター・プルーニング#
はじめに
フィルターのプルーニング (枝刈り) は、モデルの畳み込み演算から冗長なフィルターや重要でないフィルターを削除することで、モデル計算の複雑さを軽減できる高度な最適化手法です。この削除は 2 つの手順で行われます:
重要でないフィルターは、微調整による NNCF 最適化によってゼロにされます。
ゼロフィルターは、OpenVINO 中間表現 (IR) へのエクスポート中にモデルから削除されます。
NNCF のフィルター・プルーニング・メソッドはスタンドアロンで使用できますが、通常は 2 つの理由から 8 ビット量子化と組み合わせることを推奨します。まず、8 ビット量子化は、最高の精度とパフォーマンスのトレードオフを達成するという点で最良の方法であるため、フィルター・プルーニングと組み合わせることで、さらに優れたパフォーマンスを得ることができます。次に、フィルター・プルーニングとともに量子化を適用すると、精度が大きく損なわれることはありません。これは、フィルター・プルーニングによってモデルからノイズの多いフィルターが削除され、重みと活性化値の範囲が狭まり、全体的な量子化の誤差が軽減されるためです。
注
フィルター・プルーニングには通常、モデルを最初からトレーニングすることに匹敵する、長時間にわたるモデルの微調整または再トレーニングが必要です。それを怠ると、大幅な精度低下を引き起こす可能性があります。したがって、この方法を適用する場合、それに応じてトレーニング・スケジュールを調整する必要があります。
以下に、フィルター・プルーニング + QAT をモデルの適用に必要な手順を示します:
微調整によるフィルター・プルーニングの適用#
ここでは、モデルのトレーニング・スクリプトを変更し、重要でないフィルターをゼロにする基本的な手順を示します:
1. NNCF API をインポート#
このステップでは、トレーニング・スクリプトの先頭に NNCF 関連のインポート文を追加します:
import torch
import nncf # 重要 - トーチ直後にインポートする必要があります
from nncf import NNCFConfig
from nncf.torch import create_compressed_model, register_default_init_args
import tensorflow as tf
from nncf import NNCFConfig
from nncf.tensorflow import create_compressed_model, create_compression_callbacks, \
register_default_init_args
2. NNCF 構成の作成#
ここでは、モデル関連のパラメーター (“input_info” セクション) と最適化メソッドのパラメーター (“compression” セクション) で構成される NNCF 構成を定義する必要があります。
nncf_config_dict = {
"input_info": {"sample_size": [1, 3, 224, 224]}, # モデルのトレースに必要な入力形状
"compression": [
{
"algorithm": "filter_pruning",
"pruning_init": 0.1,
"params": { "pruning_target": 0.4, "pruning_steps": 15
}
},
{
"algorithm": "quantization", # デフォルト設定での 8 ビット量子化
},
]
}
nncf_config = NNCFConfig.from_dict(nncf_config_dict)
nncf_config = register_default_init_args(nncf_config, train_loader) # train_loader は torch.utils.data.DataLoader のインスタンスです
nncf_config_dict = {
"input_info": {"sample_size": [1, 3, 224, 224]}, # モデルのトレースに必要な入力形状
"compression": [
{
"algorithm": "filter_pruning",
"pruning_init": 0.1,
"params": { "pruning_target": 0.4, "pruning_steps": 15
}
},
{
"algorithm": "quantization", # デフォルト設定での 8 ビット量子化
},
]
}
nncf_config = NNCFConfig.from_dict(nncf_config_dict)
nncf_config = register_default_init_args(nncf_config, train_dataset, batch_size=1) # train_dataset は tf.data.Dataset のインスタンス
ここでは、フィルター・プルーニング・メソッドに必要なパラメーターについて簡単に説明します。完全な説明は、GitHub を参照してください。
pruning_init
- 初期のプルーニング (枝刈り) 率の目標。例えば、値0.1
は、トレーニングの開始時に、プルーニングできる畳み込みのフィルターの 10% がゼロに設定されることを意味します。pruning_target
- スケジュール終了時のプルーニング (枝刈り) 率の目標。例えば、値0.5
は、num_init_steps + pruning_steps
の数を持つエポックで、プルーニングできる畳み込みのフィルターの 50 パーセントがゼロに設定されることを意味します。pruning_steps - プルーニング率のターゲットが pruning_init から pruning_target に増加する間のエポック数。この期間中は最高の学習率を維持することを推奨します。
3. 最適化の適用#
このステップでは、前のステップで定義した構成を使用し、create_compressed_model()
API を使用して、元のモデルが NNCF オブジェクトによってラップされます。このメソッドは、圧縮コントローラーと、元のモデルと同じように使用できるラップされたモデルを返します。モデルが対応する一連の変換を実行し、最適化に必要な追加の操作を含めることができるように、このステップで最適化メソッドが適用されることに注意してください。
model = TorchModel() # torch.nn.Module のインスタンス
compression_ctrl, model = create_compressed_model(model, nncf_config)
model = KerasModel() # tensorflow.keras.Model のインスタンス
compression_ctrl, model = create_compressed_model(model, nncf_config)
4. モデルの微調整#
このステップでは、ベースライン・モデルに対して適用した方法でモデルに微調整を加えることを前提としています。フィルター・プルーニング法では、元のモデルのトレーニングに使用したのと同様のトレーニング・スケジュールと学習率を使用することを推奨します。
...# 微調整の準備、例: データセット、損失、オプティマイザーのセットアップなど
# ベースラインとして 50 エポックの量子化モデルを調整
for epoch in range(0, 50):
compression_ctrl.scheduler.epoch_step() # Epoch control API
for i, data in enumerate(train_loader):
compression_ctrl.scheduler.step() # トレーニング反復制御 API
...# トレーニング・ループ本体
...# 微調整の準備、例: データセット、損失、オプティマイザーのセットアップなど
# 最適化パラメーターを制御し、圧縮統計をダンプするための圧縮コールバックを作成
# すべての設定は compression_ctrl、つまり NNCF 設定から取得
compression_callbacks = create_compression_callbacks(compression_ctrl, log_dir="./compression_log")
# ベースラインとして 50 エポックの量子化モデルを調整
model.fit(train_dataset, epochs=50, callbacks=compression_callbacks)
5. マルチ GPU 分散トレーニング#
マルチ GPU 分散トレーニング (DataParallel ではない) の場合、微調整の前に compression_ctrl.distributed()
を呼び出す必要があります。これにより、分散モードで機能するためいくつかの調整を行えるように最適化メソッドに通知されます。
compression_ctrl.distributed() # トレーニング・ループの前に呼び出し
compression_ctrl.distributed() # トレーニングの前に呼び出し
6. 量子化モデルのエクスポート#
微調整が終了したら、量子化モデルを対応する形式にエクスポートして、さらに推論を行うことができます: PyTorch およびフリーズされたグラフの ONNX - TensorFlow 2 の場合。
compression_ctrl.export_model("compressed_model.onnx")
compression_ctrl.export_model("compressed_model.pb") # Frozen グラフへエクスポート
これらは、NNCF の QAT メソッドを適用する基本的な手順です。ただし、状況によっては、トレーニング中にモデルのチェックポイントを保存/復元する必要があります。NNCF は元のモデルを独自のオブジェクトでラップするため、これらのニーズに対応する API を提供します。
7. (オプション) チェックポイントの保存#
モデルのチェックポイントを保存するには、次の API を使用します:
checkpoint = {
'state_dict': model.state_dict(),
'compression_state': compression_ctrl.get_compression_state(),
...# 保存する残りのユーザー定義オブジェクト
}
torch.save(checkpoint, path_to_checkpoint)
from nncf.tensorflow.utils.state import TFCompressionState
from nncf.tensorflow.callbacks.checkpoint_callback import CheckpointManagerCallback
checkpoint = tf.train.Checkpoint(model=model,
compression_state=TFCompressionState(compression_ctrl),
...# 保存する残りのユーザー定義オブジェクト
)
callbacks = []
callbacks.append(CheckpointManagerCallback(checkpoint, path_to_checkpoint))
...
model.fit(..., callbacks=callbacks)
8. (オプション) チェックポイントから復元#
チェックポイントからモデルを復元するには、次の API を使用します:
resuming_checkpoint = torch.load(path_to_checkpoint)
compression_state = resuming_checkpoint['compression_state']
compression_ctrl, model = create_compressed_model(model, nncf_config, compression_state=compression_state)
state_dict = resuming_checkpoint['state_dict']
model.load_state_dict(state_dict)
from nncf.tensorflow.utils.state import TFCompressionStateLoader
checkpoint = tf.train.Checkpoint(compression_state=TFCompressionStateLoader())
checkpoint.restore(path_to_checkpoint)
compression_state = checkpoint.compression_state.state
compression_ctrl, model = create_compressed_model(model, nncf_config, compression_state)
checkpoint = tf.train.Checkpoint(model=model, ...)
checkpoint.restore(path_to_checkpoint)
詳細は、次のドキュメントをご覧ください。
量子化モデルのデプロイ#
プルーニングされたモデルには、パフォーマンス向上のため実行すべき追加の手順が必要になることがあります。このステップには、モデルからのゼロフィルターの削除が含まれます。これは、モデルがフレームワーク表現 (ONNX、TensorFlow など) から OpenVINO 中間表現に変換されるときに、モデル・トランスフォーメーション API ツールのモデル変換ステップで行われます。
プルーニングされたモデルからゼロフィルターを削除するには、次のパラメーターをモデル変換コマンドに追加します:
transform=Pruning
その後は、ベースライン・モデルと同じ方法で OpenVINO を使用してデプロイできます。OpenVINO を使用したモデルのデプロイメントの詳細については、対応するドキュメントを参照してください。