CTCLoss#
バージョン名: CTCLoss-4
カテゴリー: シーケンス処理
簡単な説明: CTCLoss は、CTC (コネクショニズム時間分類) 損失を計算します。
詳細な説明:
CTCLoss は、ロジット logits[i,:,:] の指定された入力シーケンスに対してターゲット labels[i,:] が発生する可能性 (または実際に) を推定します。要約すると、CTCLoss 操作は、ターゲット labels[i,:] にアライメントされたすべてのシーケンスを見つけ、logits[i,:,:] でアライメントされたシーケンスの対数確率を計算し、これらの対数確率の負の和を計算します。
ロジット logits の入力シーケンスは、異なる長さにできます。各シーケンスの長さ logits[i,:,:] は logit_length[i] と等しくなります。ターゲットシーケンス labels[i,:] の長さは label_length[i] と等しくなります。ターゲットシーケンスの長さは、対応する入力シーケンス logits[i,:,:] の長さを超えてはなりません。それ以外の場合、操作の動作は未定義です。
CTCLoss 計算スキーム:
ソフトマックスの公式を使用して、
logitsからi番目の入力シーケンスのタイムステップtにおけるj番目の文字の確率を計算します:
指定された
i番目のターゲットに対して、labels[i,:]からすべての位置一致したパスを検索します。デコード後に両方のチェーンが等しい場合、パスS = (c1,c2,...,cT)はターゲットG=(g1,g2,...,gT)に位置合わせます。デコードでは、ターゲットGから長さlabel_length[i]の部分文字列が抽出され、preprocess_collapse_repeated が true の場合はG内の繰り返し文字がマージされ、unique が true のは文字の出現順序で一意の要素が検索されます。デコードでは、ctc_merge_repeated が true の場合にS内の繰り返し文字がマージされ、blank_indexで表される空白文字が削除されます。デフォルトでは、blank_indexはC-1に等しくなります。ここで、Cはブランクを含むクラスの数です。例えば、デフォルトの ctc_merge_repeated、preprocess_collapse_repeated、unique そしてblank_indexの場合、長さlabel_length[i]=4のターゲットシーケンスG=(0,3,2,2,2,2,2,4,3)は(0,3,2,2)、長さlogit_length[i]=9のパスS=(0,0,4,3,2,2,4,2,4)も(0,3,2,2)になります。ここでC=5です。0,4,3,3,2,4,2,2,2など、Gと一致する他のパスも存在します。ターゲットlabel[:,i]との位置合わせがチェックされるパスは、長さがlogit_length[i] = L_iである必要があります。位置合わせされたパス (位置合わせ) の確率を次のように計算します:
最後に、見つかったすべてのアライメントの合計確率の負の対数を計算します:
注 1: この計算スキームは、最適な実装の手順を提供するものではなく、説明をわかりやすくするために役立ちます。
注 2: これは、整列されたパスの対数確率 \(\ln p(S)\) を入力ロジットの log-softmax の合計として計算することを推奨します。計算中のアンダーフローやオーバーフローを回避するのに役立ちます。整列されたパスの対数確率があれば、これらのパスの合計確率の対数は次のように計算できます:
属性
preprocess_collapse_repeated
説明: preprocess_collapse_repeated は、損失計算の前の前処理ステップのフラグであり、損失に渡される
labels[i,:]内の繰り返しラベルが単一のラベルにマージされます。値の範囲: true または false
タイプ:
booleanデフォルト値 : false
必須: いいえ
ctc_merge_repeated
説明: ctc_merge_repeated は、CTC 損失計算中の潜在的な位置合わせで繰り返される文字をマージするためのフラグです。
値の範囲: true または false
タイプ:
booleanデフォルト値: true
必須: いいえ
unique
説明: unique は、潜在的なアライメントと照合する前に、ターゲット
labels[i,:]の一意の要素を検索するフラグです。処理されたlabels[i,:]内の固有の要素は、元のlabels[i,:]での出現順に並べ替えられます。例えば、長さlabels[i,:]=(0,1,1,0,1,3,3,2,2,3)のlabel_length[i]=10の処理されたシーケンスは、unique が true の場合(0,1,3,2)になります。値の範囲: true または false
タイプ:
booleanデフォルト値 : false
必須: いいえ
入力
1:
logits- ロジットのシーケンスのバッチを含む入力テンソル。要素のタイプは T_F です。テンソルの形状は[N, T, C]です。ここで、Nはバッチサイズ、Tは最大シーケンス長、Cはブランクを含むクラスの数です。必須。2:
logit_length- タイプ T1 および形状[N]の 1D 入力テンソル。テンソルは、T以下の負ではない値で構成されなければなりません。ロジットlogits[i,:,:]の入力シーケンスの長さ。必須。3:
labels- T2 タイプの形状[N, T]を持つ 2D テンソル。ターゲットシーケンスlabels[i,:]の長さは、label_length[i]に等しく、blank_indexを除く範囲[0; C-1]が含まれている必要があります。必須。4:
label_length- タイプ T1 および形状[N]の 1D テンソル。テンソルは、すべての可能なiに対してTおよびlabel_length[i] <= logit_length[i]以下の負ではない値で構成されなければなりません。必須。5:
blank_index- T2 タイプのスカラー。空白のラベルに使用するクラスのインデックスを設定します。デフォルト値はC-1です。オプションです。
Output
1: 形状
[N]の出力テンソル、アライメントの対数確率の負の合計。要素のタイプは T_F です。
タイプ
T_F: サポートされている浮動小数点タイプ。
T1、T2:
int32またはint64。
例
<layer ... type="CTCLoss" ...>
<input>
<port id="0">
<dim>8</dim>
<dim>20</dim>
<dim>128</dim>
</port>
<port id="1">
<dim>8</dim>
</port>
<port id="2">
<dim>8</dim>
<dim>20</dim>
</port>
<port id="3">
<dim>8</dim>
</port> <port id="4"> <!-- blank_index 値: 120 -->
</input>
<output>
<port id="0">
<dim>8</dim>
</port>
</output>
</layer>