引き続きPyTorch LSTMを使って学習したときに発生した問題の紹介です。シンプルなLSTMでは発生していなかったのですが、グラフの規模が大きいシステムを扱ったとたんにハマりました。結果的には初歩的なミスだったのですが…
ここではPyTorch特有のグラフ保存由来の問題に触れますが、ここに来た人は単純にデータセットが大きすぎてメモリー不足の人も多くいらっしゃるような気がしたので、swapの話しも最後に追記しました。
環境
IDEはPyCharmを使用、PyTorchベースのMulti-timescale LSTMを構成し(実験的に)CPU上で学習。
発生した問題
1エポックが終わらないうちに Process finished with exit code 137 (interrupted by signal 9: SIGKILL) で終了。メモリー不足によるエラーとのこと。30GiBのメモリーが約200イテレーションで溢れたことになります…
学習中にfreeコマンドでメモリーを観察していると、たしかに指数関数的にメモリーが消費されてる !
解決への糸口
GPUではないけれども、どうも【PyTorch】不要になった計算グラフを削除してメモリを節約 と同じ原因っぽい。
プログラムは PyTorchでclass_weightを適用するには で用いたものをベースにMulti-timescaleに拡張したもので、表示用にlossをテンソルのまま蓄積する部分がありました。これをやるとどうやらlossのテンソルだけでなく計算グラフ全体がメモリー上に蓄積されてしまうということですね。いかんいかん。
ということで、lossの蓄積にはテンソルそのものではなく数値だけ取り出すように修正。accuracyは純粋なスカラーテンソルですが、気持ち悪いのでついでに同様に修正。めでたくSIGKILLはなくなりました。
model.train() class_weight = [weight_for_0, weight_for_1] for i in range(epochs): losses, acces = [], [] for b, tup in enumerate(dataloader): X, y = tup optimizer.zero_grad() y_pred = model(X) _, p = torch.max(y_pred, 1) _, l = torch.max(y, 1) acc = torchmetrics.functional.accuracy(p, l)acces.append(acc)acces.append(acc.item()) loss = torch.tensor(0.) for j in range(y.shape[1]): loss += loss_function(y_pred[:,j], y[:,j]) * class_weight[j]losses.append(loss)losses.append(loss.item()) loss.backward() optimizer.step() train_loss = sum(losses)/len(losses) train_acc = sum(acces)/len(acces)print('epoch:', i, 'loss:', train_loss.item(), 'acc:', train_acc.item())print('epoch:', i, 'loss:', train_loss, 'acc:', train_acc)
さいごに
結局原因は、lossをテンソルの状態で蓄積したために計算グラフ全体が蓄積されることになり、メモリーがあっという間に枯渇したということです。結果的には初歩的なミスでした。
毎日のようにいろいろハマりますが、解決するたびに深層学習自体の理解が深まった気になれるから不思議です。
ちなみに…
単にデータセットが巨大でCPUメモリーが不足している場合にも exit code 137 (SIGKILL) は発生します。そんなときは、swapでHDDやSSD上に仮想メモリーを確保して対処しましょう。
まず、sudo swapon –show でswapメモリーがすでに設定されているか確認します。グッディーのケース (Ubuntu 20.04.2 LTS) だと、/swapfile がすでに確保されていたので、
cd / sudo swapoff swapfile sudo fallocate -l 60G swapfile sudo mkswap swapfile sudo swapon swapfile
でたとえば60GBの仮想メモリーを確保して対処します。
動作確認は free -h で。プログラムを走らせて、特にデータセットがメモリー上に展開されているときに free -h でどこまでメモリーがひっ迫しているか確認し、必要に応じて容量を追加しましょう。
最近はデータセットの展開もフレームワーク化されて、なかなか思うようにメモリー削減できなかったりしますが、そんなときはあまり Runtime での Augmentation などに固執せずに、swapメモリーで逃げるのも一手ですね。