やはり元の話しでの解決方法は、辞書の中身の違いを分析し、ある意味無理やり合わせることによって解決したのでスッキリしていないのと、エラーは出なくなったけどもパラメータがきちんとロードされていない場合も見受けられたので、もう少し調べてみた。
糸口
SAVING AND LOADING WEIGHTS の Checkpoint loading のところに、
# if you train and save the model like this it will use these values when loading
# the weights. But you can overwrite this
LitModel(in_dim=32, out_dim=10)
# uses in_dim=32, out_dim=10
model = LitModel.load_from_checkpoint(PATH)
# uses in_dim=128, out_dim=10
model = LitModel.load_from_checkpoint(PATH, in_dim=128, out_dim=10)
とある。引数を持たせたり、変えたりいろいろできるのね。ふむふむ。
グッディーが今回参考にさせて頂いているモデルは、以下のような感じで定義されていて、(中身はかなり省略)
class TrainClassifier(pytorch_lightning.LightningModule): def __init__(self, hparams, train_subjects, validate_subjects, class_weights=None): super(TrainClassifier, self).__init__() self._model = VGG16 based model self._criterion = binary cross entropy with pos_weights=class_weights self._train_subjects = train_subjects self._validate_subjects = validate_subjects self.hparams = hparams def forward(self, inputs): ... def training_step(self, batch, batch_idx): ... def validation_step(self, batch, batch_idx): ... def validation_epoch_end(self, outputs): ... def configure_optimizers(self): ... def train_dataloader(self): ... def val_dataloader(self): ...
以下のようなコードでトレーニングします。
from pytorch_lightning import Trainer from fytorch_lightning.callbacks import ModelCheckpoint model = TrainClassifier(hparams, train_subjects, validate_subjects, class_weights) checkpoint_callback = ModelCheckpoint(...) trainer = Trainer(...) trainer.fit(model)
トレーニングの再開
このプログラムでセーブされたckpt fileを再びロードしてトレーニングを継続するには、
model = TrainClassifier.load_from_checkpoint('ckpt file path', hparams, train_subjects, validate_subjects, class_weights) checkpoint_callback = ModelCheckpoint(...) trainer = Trainer(...) trainer.fit(model)
でもできたし、SAVING AND LOADING WEIGHTS の Restoring Training State にあるように、
model = TrainClassifier(hparams, train_subjects, validate_subjects, class_weights) trainer = Trainer(resume_from_checkpoint='checkpoint file path') trainer.fit(model)
でもできた。うん、たしかにこれは便利。
テストへの利用
グッディーはとりあえず TrainClassifier のサブセット(テストに不要なdataloaderなどを削除したもの)として以下のようなテスト用の EvalClassifier を作りました。
class EvalClassifier(pytorch_lightning.LightningModule): def __init__(self, device_id): super(EvalClassifier, self).__init__() self._device_id = device_id self._model = VGG16 based model .to(self.device_id) self._criterion = binary cross entropy with pos_weights=class_weights self._transform = transforms.Compose(...) ...
ここで loss_fn が残ってしまっているのは、ckpt file の辞書データの中の ‘state_dict’ に pos_weights があったため、ダミーで残しました。(もっといい方法があるかも知れない)
transform はこのモデルに喰わすデータ専用の変換。こういったこのモデルでのテスト専用のオブジェクトやメソッドをここに定義していくのもいいのかも知れない。で、
classifier = EvalClassifier.load_from_checkpoint('ckpt file path', device_id='cuda:0') estimated_class = classifier(input_image)
これで無事動作しました。
おわりに
前回よりはスッキリする結果が得られた気はしています。が、loss_fn の pos_weights のようなものが ‘state_dict’ に含まれている場合は、ちょっとモヤモヤが残ります。
もちろん’state_dict’を直接操作しても良いのだけれどちょっと強引すぎるし、Trainerに ckpt file を指示して再開できるのは便利だし、まあそういうものということにします。