AI

numpy.reshapeは気をつけて使おう !

KerasからPyTorchへの乗りかえの一環で数日ハマっていたのですがようやく抜け出せました。終わってみれば大したことではなかったのですが、Kerasで一旦動作しているという気のゆるみが誘発したケアレスミスでしたね。こんな情けないミスをする人はそうはいないかも知れませんが、一応まとめておこうと思います。

今回はPyTorchが悪いわけではまったくなかったのですが、PyTorch上でハマったということで…

やろうとしていたこと

Kerasで動作していたLSTMの瞬き検出器をPyTorchでも動作させようとしていました。LSTMではTensorflow nativeでも収束しない問題で結構ハマったことがあり、Tensorflow Kerasに移行したとたんにサクサク学習できたというイケナイ経験もあり、どこかにKerasは賢いという先入観があったのも、今回の基本的なケアレスミスをすぐに探せなかった原因のひとつだった気がします。

ハマった症状

LSTMモデルを作り、目の特徴量のデータセットで学習したところ、過学習気味ではあったものの無事収束。それをビデオやカメラ映像でリアルタイム推定するテストプログラムで試したところ、全く瞬きを検出しない、うーむ。

原因

もともとのデータセットがKeras用の配列(データセットの数 x 目の特徴量の数 x 時間サンプリング数)になっていたのをPyTorch用の配列(データセットの数 x 時間サンプリング数 x 目の特徴量の数)に変換するときに、なんとも気の抜けたことをしていました。

変更前:X_train = np.array(data_list).reshape(len(data_list), dnum, tnum) <- for Keras
変更後:X_train = np.array(data_list).reshape(len(data_list), tnum, dnum) <- for PyTorch

こんなことしちゃダメダメ。正解はこれです。

X_train = np.array(data_list).transpose(0, 2, 1)

わかるまでの道のり

とにもかくにも学習は収束していたので、テストプログラムの学習済みモデルのロード部分を疑いました。あと学習時は512のバッチサイズをテスト時は1にする必要がありますから、そのあたりも疑って学習時のvalidationのバッチサイズを変えて様子を見てみたり。

学習時の収束後のvalidation推定結果では[0.9, 0.1]とか[0.2, 0.8]とかまあまあきちんと見分けてくれているのに、なぜかテストプログラムでInferenceすると[0.5 0.5]あたりをウロウロ。

さんざんモデルまわりを疑ってみたものの症状に変化なく、次にようやくデータ構造に目を向けることになります。

テストプログラムにデータセットのpickleファイル由来のデータを喰わせたところ、正常動作。ということはテストプログラムのデータ構成方法が悪いわけ ? でもリストアペンドしてるだけだし、出来上がったリストはtnum x dnumだからそのままテンソルに変換してPyTorchのモデルに喰わせればよいので間違えようがないし…

ということは学習用のデータが間違ってる ? 収束はしているけど… ということで学習プログラムのデータ生成部分を見直してみると…

このreshape変だな。あれ ? これ転置になっていないのでは… ? ということで転置にして一度収束したモデルに新しいデータで学習を継続してみると… ビンゴ ! [0.5, 0.5]あたりをウロウロしてる。

ということで、転置に直して学習しなおしたところ、テストプログラムでもきちんと瞬きが検出できるようになったとさ。おしまい。

さいごに

今回ハマりにハマった原因として、

  • PyTorchのLSTM実装は初めてで、まずモデル周辺を疑ってしまった
  • (結果的に間違ったデータで)幸運にも学習は収束してしまった
  • 今までもPyTorchでモデルのsave/loadで少なからずハマったことがある
  • 同じデータセットでKerasでは動いていた

このあたりが大きく作用していた気がします。

しかしもしもこれで学習が収束していなかったら、LSTMモデルの構成そのものを疑い続け、原因究明はさらに困難を極めていた可能性があります。結果的には間違った学習をしていたわけですが、学習が収束してくれていたおかげで、この間違いに到達できたのはある意味幸運だったと言えます。

そして思うのは、reshapeがいらないところでも何となくデータ構造が見えるようにreshapeを使う自分の実装。この「何となく」がよろしくなかったです。reshapeやviewを何となく使うよりも、データ構造を意識してtransposeやsqueeze/unsqueezeを駆使したほうが事故は避けられたと反省しています。

とにもかくにも解決してよかった。いつも心のどこかで「Kerasに戻りたい、Tensorflowのxx」と思っている自分がいるので、問題が解決するたびにPyTorchが好きになります。


   
関連記事
  • torch.tensorとtorch.Tensor
  • PyTorch Tutorial その4 – Training Classifier
  • PyTorch LightningのckptファイルをLoadするのにはまった話のその後
  • PyTorch Tutorial その2 – torch.autograd
  • PyTorch LightningのckptファイルをLoadするのにはまった話
  • PyTorchの学習済みmodel fileの拡張子が.pth.tarではまった話

    コメントを残す

    *

    CAPTCHA