AI

PyTorch LightningのckptファイルをLoadするのにはまった話のその後

やはり元の話しでの解決方法は、辞書の中身の違いを分析し、ある意味無理やり合わせることによって解決したのでスッキリしていないのと、エラーは出なくなったけどもパラメータがきちんとロードされていない場合も見受けられたので、もう少し調べてみた。

糸口

SAVING AND LOADING WEIGHTSCheckpoint loading のところに、

# if you train and save the model like this it will use these values when loading
# the weights. But you can overwrite this
LitModel(in_dim=32, out_dim=10)

# uses in_dim=32, out_dim=10
model = LitModel.load_from_checkpoint(PATH)

# uses in_dim=128, out_dim=10
model = LitModel.load_from_checkpoint(PATH, in_dim=128, out_dim=10)

とある。引数を持たせたり、変えたりいろいろできるのね。ふむふむ。

グッディーが今回参考にさせて頂いているモデルは、以下のような感じで定義されていて、(中身はかなり省略)

class TrainClassifier(pytorch_lightning.LightningModule):
    def __init__(self, hparams, train_subjects, validate_subjects, class_weights=None):
        super(TrainClassifier, self).__init__()

        self._model = VGG16 based model
        self._criterion = binary cross entropy with pos_weights=class_weights
        self._train_subjects = train_subjects
        self._validate_subjects = validate_subjects
        self.hparams = hparams

    def forward(self, inputs):
        ...
    def training_step(self, batch, batch_idx):
        ...
    def validation_step(self, batch, batch_idx):
        ...
    def validation_epoch_end(self, outputs):
        ...
    def configure_optimizers(self):
        ...
    def train_dataloader(self):
        ...
    def val_dataloader(self):
        ...

以下のようなコードでトレーニングします。

from pytorch_lightning import Trainer
from fytorch_lightning.callbacks import ModelCheckpoint

model = TrainClassifier(hparams, train_subjects, validate_subjects, class_weights)
checkpoint_callback = ModelCheckpoint(...)
trainer = Trainer(...)
trainer.fit(model)

トレーニングの再開

このプログラムでセーブされたckpt fileを再びロードしてトレーニングを継続するには、

model = TrainClassifier.load_from_checkpoint('ckpt file path', hparams, train_subjects, validate_subjects, class_weights)
checkpoint_callback = ModelCheckpoint(...)
trainer = Trainer(...)
trainer.fit(model)

でもできたし、SAVING AND LOADING WEIGHTSRestoring Training State にあるように、

model = TrainClassifier(hparams, train_subjects, validate_subjects, class_weights)
trainer = Trainer(resume_from_checkpoint='checkpoint file path')
trainer.fit(model)

でもできた。うん、たしかにこれは便利。

テストへの利用

グッディーはとりあえず TrainClassifier のサブセット(テストに不要なdataloaderなどを削除したもの)として以下のようなテスト用の EvalClassifier を作りました。

class EvalClassifier(pytorch_lightning.LightningModule):
    def __init__(self, device_id):
        super(EvalClassifier, self).__init__()

        self._device_id = device_id
        self._model = VGG16 based model .to(self.device_id)
        self._criterion = binary cross entropy with pos_weights=class_weights
        self._transform = transforms.Compose(...)
        ...

ここで loss_fn が残ってしまっているのは、ckpt file の辞書データの中の ‘state_dict’ に pos_weights があったため、ダミーで残しました。(もっといい方法があるかも知れない)

transform はこのモデルに喰わすデータ専用の変換。こういったこのモデルでのテスト専用のオブジェクトやメソッドをここに定義していくのもいいのかも知れない。で、

classifier = EvalClassifier.load_from_checkpoint('ckpt file path', device_id='cuda:0')
estimated_class = classifier(input_image)

これで無事動作しました。

おわりに

前回よりはスッキリする結果が得られた気はしています。が、loss_fn の pos_weights のようなものが ‘state_dict’ に含まれている場合は、ちょっとモヤモヤが残ります。

もちろん’state_dict’を直接操作しても良いのだけれどちょっと強引すぎるし、Trainerに ckpt file を指示して再開できるのは便利だし、まあそういうものということにします。


   
関連記事
  • Ubuntu20.04+CUDA11.2+PyTorch1.8.1+cu111がRTX3090のGPUで動作した !
  • PyTorchでclass_weightを適用するには
  • torch.tensorとtorch.Tensor
  • torch.viewも気をつけて使おう !
  • Torchvisionのtransforms.Composeを使いこなしてtraining accuracyを上げよう !
  • PyTorchで学習済みモデルの中間層出力の取得

    コメントを残す

    *

    CAPTCHA