AI

torch.squeeze()よりtorch.view()が安心だった話し

torch.squeeze()を使う機会って場当たり的にやってきませんか ? これを計画的に使いこなしている人がいるとしたら、尊敬します。torch.unsqueeze()はまだ必要に応じて使っている感がありますが、torch.squeeze()は「なんか結果が二重大括弧([[tensor, tensor, …]])で出てきたから次元落とそ」的な使い方が個人的にはほとんどだったりします。恥ずかしながら。

よくあるのが、

4 Class の Classification の Trainng 中、たとえば Batch size を 256 にしたとすると、こんな感じで推論結果が出て来ますよね。

$ print(y_pred.shape, y_pred)
torch.Size([256, 4]) 
tensor([[5.9578e-08, 1.3575e-08, 1.0000e+00, 3.8994e-09],
        [6.7290e-42, 1.1879e-09, 1.1879e-09, 1.0000e+00],
        [1.0000e+00, 1.4400e-26, 0.0000e+00, 1.7529e-24],
        ...,
        [3.6983e-06, 2.8791e-06, 9.9999e-01, 3.9159e-20],
        [2.6434e-15, 3.3333e-01, 3.3333e-01, 3.3333e-01],
        [1.1759e-10, 1.9725e-09, 5.0000e-01, 5.0000e-01]], device='cuda:0', grad_fn=<SoftmaxBackward>) 

Classification の結果は以下のようになります。

$ _, pidx = torch.max(y_pred, 1) 
$ print(pidx.shape, pidx)
torch.Size([256])
tensor([2, 3, 0, 0, 3, 0, 0, 0, ..., 2, 2, 2, 2, 2, 1, 2]

これに対して Dataloader から提示されるラベルは、このような状態です。

$ print(y.shape, y)
torch.Size([256, 1]) tensor([[0],
        [3],
        [0],
         :
        [3],
        [3],
        [2]], device='cuda:0')

これでは Loss や Accuracy を計算できませんので、次元を合わせます。

$ print(y.squeeze().shape, y.squeeze())
torch.Size([256])
tensor([0, 3, 0, 3, 2, 0, 1, 0, ..., 0, 2, 0, 0, 3, 3, 2], device='cuda:0')

$ print(y.view(-1).shape, y.view(-1)
torch.Size([256])
tensor([0, 3, 0, 3, 2, 0, 1, 0, ..., 0, 2, 0, 0, 3, 3, 2], device='cuda:0')

この場合は torch.squeeze() でも torch.view(-1) でも同じ結果が得られますので、

$ loss = loss_function(y_pred, y.squeeze())
$ acc = torchmetrics.functional.accuracy(pidx, y.squeeze())

$ loss = loss_function(y_pred, y.view(-1))
$ acc = torchmetrics.functional.accuracy(pidx, y.view(-1))

どちらでも正しく計算できます。

しかしながら、

たとえば Validation 時に Batch size を 1 にして検証すると、

$ print(y_pred.shape, y_pred)
torch.Size([1, 4])
tensor([[1.5693e-03, 4.1396e-01, 5.8413e-01, 3.3415e-04]], device='cuda:0', grad_fn=<SoftmaxBackward>) 

$ _, pidx = torch.max(y_pred, 1) 
$ print(pidx.shape, pidx)
torch.Size([1])
tensor([2], device='cuda:0')

Inference の結果はこのようになりますが、それに対して、

$ print(y.shape, y)
torch.Size([1, 1])
tensor([[0]], device='cuda:0')

$ print(y.squeeze().shape, y.squeeze())
torch.Size([])
tensor(0, device='cuda:0')

$ print(y.view(-1).shape, y.view(-1)
torch.Size([1])
tensor([0], device='cuda:0')

torch.squeeze() と torch.view(-1) の結果が異なることがわかります。

もう一度、torch.squeeze() の仕様 を確認してみましょう。Tensor の次元のうち、すべてのサイズ 1 の次元を削除する、とあります。
そうなんです。Training 時は Batch size = 256 でしたので y.shape は [256, 1] でしたが、Validation 時は Batch size = 1 にしたので y.shape は [1, 1] になりますから、y.squeeze() は Scalar の Tensor になってしまったわけです。

ではどうしたら Batch size によらず同じ書き方で行けるでしょう ?

$ loss = loss_function(y_pred, y.squeeze(0))
$ acc = torchmetrics.functional.accuracy(pidx, y.squeeze(0))

$ loss = loss_function(y_pred, y.view(-1))
$ acc = torchmetrics.functional.accuracy(pidx, y.view(-1))

そうです、削除する次元を指定すればよいのです。この場合だと、dim = 0 ですね。

まあそうなのですが、

結局次元を意識するのであれば、torch.view() でいいんじゃないの、というところで。torch.squeeze() で削除する次元はひとつしか指定できないですから、複雑なことをしようとすると torch.squeeze() を複数回呼ばないとできないことになります。一方それは torch.view() であれば1回でできることなわけで…

やはり torch.squeeze() は場当たり的に使いたいと思ってしまいますが、そういう使い方をすると、上の例のように、Training 時には動作していたものが、実際の Inference にコピペして使おうとしてもエラーになる、なんていう落とし穴がありますよ、だったら次元を意識して torch.view() を使った方が安心ですね、というお話しでした。


   
関連記事
  • PyTorch LightningのckptファイルをLoadするのにはまった話
  • Torchvisionのtransforms.Composeを使いこなしてtraining accuracyを上げよう !
  • PyTorch LightningのckptファイルをLoadするのにはまった話のその後
  • PyTorchでclass_weightを適用するには
  • kernel_initializerって学習の収束に大事かも
  • torch.viewも気をつけて使おう !

    コメントを残す

    *

    CAPTCHA