バランスの取れたデータセットが常に用意できればよいのですが、ラベルごとにデータの数がばらついてしまうことはよくあります。
そんなとき、Kerasだとclass_weightを用意してmodel.fit()の引数に適用すればよかったので悩むこともなかったのですが、PyTorchの場合基本的にはLoss関数の引数で定義すればよいようです。Loss関数のリストはこちら(Loss Functions)。
ただしすべてのLoss関数にweight引数が用意されているわけではなく、nn.NLLLoss
, nn.CrossEntropyLoss, nn.BCELoss, nn.BCEWithLogitsLoss
と言ったよく使われるLoss関数に限られるようです。
ではほかのLoss関数を利用したい場合はどうしたらよいのか、というのが事の発端でした。
やりたかったこと
瞬きの検出をLSTMで構成していました。仕様の詳細は省略しますが、ラベルとしては1 : 瞬き、0 : それ以外というシンプルなものです。ラベルの数は1がかなり少なくなってしまうことが容易に想像できると思います。
これに対して、アンダーサンプリング / オーバーサンプリングと言ったデータセットの操作による解決方法があります。これも有用な手法ですが、以前使用していたKerasではclass_weightを利用していました。
ここでは以下を前提とします。
- negative : ラベル0の数
- positive : ラベル1の数
- weight_for_0 : 1. / negative * (negative + positive)
- weight_for_1 : 1. / positive * (negative + positive)
- class_weight = {0 : weight_for_0, 1 : weight_for_1}
Kerasの場合
model.fit()の引数にclass_weightを定義することができました。たとえば以下のような感じです。
hist = model.fit(x_train, y_train, batch_size=BATCH_SIZE, epochs=EPOCKS, validation_data=(x_test, y_test), callbacks=[early_stopping, checkpoint], class_weight=class_weight)
つまりLoss関数如何によらずclass_weightを定義することができたので、悩むことがなかったのです。
PyTorchでの解決策
今回のケースでは nn.L1Loss を使っていましたが、引数に weight はありません。さてどうしようかと思ったところで参考にしたのは Dealing with imbalanced datasets in pytorch です。こちらによれば、loss を計算する際に weight を掛けて解決することを推奨しています。
ということで以下のように解決しました(抜粋)。
model.train() class_weight = [weight_for_0, weight_for_1] for i in range(epochs): losses, acces = [], [] for b, tup in enumerate(dataloader): X, y = tup optimizer.zero_grad() y_pred = model(X) _, p = torch.max(y_pred, 1) _, l = torch.max(y, 1) acc = torchmetrics.functional.accuracy(p, l) acces.append(acc) loss = torch.tensor(0.) for j in range(y.shape[1]): loss += loss_function(y_pred[:,j], y[:,j]) * class_weight[j] losses.append(loss) loss.backward() optimizer.step() train_loss = sum(losses)/len(losses) train_acc = sum(acces)/len(acces) print('epoch:', i, 'loss:', train_loss.item(), 'acc:', train_acc.item())
このコードだとメモリ消費が著しいことが発覚しました。修正版はこちら。
結果
今回のケースでは、
- class_weightを導入したほうが、学習速度が圧倒的に早かった
- 256エポックでの収束値のAccuracyは双方とも同程度の0.95だった
という結果でした。
さいごに
PyTorchでlossの計算にclass_weightを適用し、少なくとも学習速度に非常に効果的であることがわかりました。
2クラス分類問題だと出力数は1でも0/1で分類できますが、class_weightを適用するためには出力数を2にする必要があります。また3クラス以上の多クラス分類にも本手法が容易に適用できることは明らかですよね。
今回のケースでは最終的なAccuracyはweight_classを導入しなくても同程度を達成できましたが、複雑な問題になればなるほどこのclass_weightの効果は重要になってくると考えられます。
以上、PyTorchでのclass_weightの適用でした。