AI

PyTorchの学習済みmodel fileの拡張子が.pth.tarではまった話

ども、グッディーです。
PyTorch初心者ならではのはまりどころをご紹介していくcorner。

今回は視線推定の論文を読んでいて、dataset, source code, 学習済みmodelが上がっていたので試していたときのこと。

PyTorchでmodel fileとして.pthが使われるのは知っていました。そして今回のmodel file名がxxxxx.pth.tarだったので、脊髄反射的に”tar xvf”や7-zipで解凍を試みるもできず。

またuploadしている風で実は使えないやつかー、とか思いながら、念のため”.pth.tar”で調べてみると、SAVING AND LOADING MODELSに:

When saving a general checkpoint, to be used for either inference or resuming training, you must save more than just the model’s state_dict. It is important to also save the optimizer’s state_dict, as this contains buffers and parameters that are updated as the model trains. Other items that you may want to save are the epoch you left off on, the latest recorded training loss, external torch.nn.Embedding layers, etc. As a result, such a checkpoint is often 2~3 times larger than the model alone.

To save multiple components, organize them in a dictionary and use torch.save() to serialize the dictionary. A common PyTorch convention is to save these checkpoints using the .tar file extension.

あるいはこちらも:

When saving a model comprised of multiple torch.nn.Modules, such as a GAN, a sequence-to-sequence model, or an ensemble of models, you follow the same approach as when you are saving a general checkpoint. In other words, save a dictionary of each model’s state_dict and corresponding optimizer. As mentioned before, you can save any other items that may aid you in resuming training by simply appending them to the dictionary.

A common PyTorch convention is to save these checkpoints using the .tar file extension.

と書かれていました。

PyTorchの推奨は、state_dictのsave/loadです。

torch.save(model.state_dict(), PATH)

しかし、学習の中断/再開を行う場合や、GAN, sequence-to-sequenceのように複数moduleを扱う場合など、単一のstate_dictだけではなく他のitemも辞書dataに追加したい場合があります。

torch.save({
‘modelA_state_dict’: modelA.state_dict(),
‘modelB_state_dict’: modelB.state_dict(),
‘optimizerA_state_dict’: optimizerA.state_dict(),
‘optimizerB_state_dict’: optimizerB.state_dict(),
}, PATH)

PtTorchの慣習として、後者の場合は.tarの拡張子をつけましょう、とのこと。

そりゃ解凍できないわ。


   
関連記事
  • Ubuntu20.04+CUDA11.2+PyTorch1.8.1+cu111がRTX3090のGPUで動作した !
  • numpy.reshapeは気をつけて使おう !
  • PyTorchで突然malloc(): invalid next size (unsorted)が出たときの対処
  • torch.squeeze()よりtorch.view()が安心だった話し
  • Pytorch Tutorial その1 – Tensor
  • PyTorch LightningのckptファイルをLoadするのにはまった話

    コメントを残す

    *

    CAPTCHA