PyTorch Lightningをベースに書かれた画像認識系のソースコードを拡張して自作データセットで学習させたときの苦労話し。
前提
- torch.nn.Modulesでベースモデル (_model) lを定義
- pytorch_lightning.LightningModuleで学習システム (model) を定義
- pytorch_lightning.callbacks.ModelCheckPointでcheckpoint_callbackを定義
- pytorch_lightning.Trainerでcallbackなどの学習条件trainerを定義
- trainer.fit(model)で学習
やりたかったこと
ModelCheckPointで吐かれるckptファイルを使ってテストする。
はまったこと
SAVING AND LOADING WEIGHTS を読むと、シンプルなmodel定義だと model.load_from_checkpoint(ckpt_file_path) で行けるような雰囲気だが、今回のmodelにはdataloaderなど広範囲に定義されていて、いろいろ怒られた。
テストには学習用データセットなどは不要だが、きちんとmodelを定義しないとあれがないこれがないと怒られ断念。
また、_model.load_state_dict(torch.load(ckpt_file_path)) だと、Missing keyとUnexpected keyがずらっと表示されエラーになる。
Missing key(s) in state_dict: “left_features.0.weight”, “left_features.1.weight”, “left_features.1.bias”, …
Unexpected key(s) in state_dict: “epoch”, “global_step”, “pytorch-lightning_version”, …
解決
仕方ないので、ckptファイルのkeysを眺めてみると、こんな感じ:
dict_keys ([“epoch”, “global_step”, “pytorch-lightning_version”, “state_dict“, “callbacks”, “optimizer_states”, “lr_schedulers”, “hparams_name”, “hyper_parameters”])
‘state_dict’はあるので、_model.load_state_dict(torch.load(ckpt file path)[‘state_dict’]) を試してみると、またまたrequired keys とunexpected keysがずらっと並ぶが、さっきと違うのは、ほぼ同じkeysが並んでいること。
Missing key(s) in state_dict: “left_features.0.weight”, “left_features.1.weight”, “left_features.1.bias”, …
Unexpected key(s) in state_dict: “_model.left_features.0.weight”, “_model.left_features.1.weight”, “_model.left_features.1.bias”, …
Missing keys & unexpected keys in state_dict when loading self trained model によると、引数にstrict=Falseを加えると厳密な比較を行わないらしく、最終的には _model.load_state_dict(torch.load(ckpt file path)[‘state_dict’], strict=False) で無事ロードできた !
おわりに
こんなんでええんかいなという気持ちは拭えませんが、今はいろいろ試しながら理解を深めていこうと思います。