AI

PyTorch LightningのckptファイルをLoadするのにはまった話

PyTorch Lightningをベースに書かれた画像認識系のソースコードを拡張して自作データセットで学習させたときの苦労話し。

前提

  • torch.nn.Modulesでベースモデル (_model) lを定義
  • pytorch_lightning.LightningModuleで学習システム (model) を定義
  • pytorch_lightning.callbacks.ModelCheckPointでcheckpoint_callbackを定義
  • pytorch_lightning.Trainerでcallbackなどの学習条件trainerを定義
  • trainer.fit(model)で学習

やりたかったこと

ModelCheckPointで吐かれるckptファイルを使ってテストする。

はまったこと

SAVING AND LOADING WEIGHTS を読むと、シンプルなmodel定義だと model.load_from_checkpoint(ckpt_file_path) で行けるような雰囲気だが、今回のmodelにはdataloaderなど広範囲に定義されていて、いろいろ怒られた。

テストには学習用データセットなどは不要だが、きちんとmodelを定義しないとあれがないこれがないと怒られ断念。

また、_model.load_state_dict(torch.load(ckpt_file_path)) だと、Missing keyとUnexpected keyがずらっと表示されエラーになる。

Missing key(s) in state_dict: “left_features.0.weight”, “left_features.1.weight”, “left_features.1.bias”, …
Unexpected key(s) in state_dict: “epoch”, “global_step”, “pytorch-lightning_version”, …

解決

仕方ないので、ckptファイルのkeysを眺めてみると、こんな感じ:

dict_keys ([“epoch”, “global_step”, “pytorch-lightning_version”, “state_dict“, “callbacks”, “optimizer_states”, “lr_schedulers”, “hparams_name”, “hyper_parameters”])

‘state_dict’はあるので、_model.load_state_dict(torch.load(ckpt file path)[‘state_dict’]) を試してみると、またまたrequired keys とunexpected keysがずらっと並ぶが、さっきと違うのは、ほぼ同じkeysが並んでいること。

Missing key(s) in state_dict: “left_features.0.weight”, “left_features.1.weight”, “left_features.1.bias”, …
Unexpected key(s) in state_dict: “_model.left_features.0.weight”, “_model.left_features.1.weight”, “_model.left_features.1.bias”, …

Missing keys & unexpected keys in state_dict when loading self trained model によると、引数にstrict=Falseを加えると厳密な比較を行わないらしく、最終的には _model.load_state_dict(torch.load(ckpt file path)[‘state_dict’], strict=False) で無事ロードできた !

おわりに

こんなんでええんかいなという気持ちは拭えませんが、今はいろいろ試しながら理解を深めていこうと思います。


   
関連記事
  • torch.tensorとtorch.Tensor
  • PyTorch LightningのckptファイルをLoadするのにはまった話のその後
  • numpy.reshapeは気をつけて使おう !
  • The NVIDIA driver on your system is too oldって !!!
  • PyTorchで突然malloc(): invalid next size (unsorted)が出たときの対処
  • Torchvisionのtransforms.Composeを使いこなしてtraining accuracyを上げよう !

    コメントを残す

    *

    CAPTCHA