torch.squeeze()を使う機会って場当たり的にやってきませんか ? これを計画的に使いこなしている人がいるとしたら、尊敬します。torch.unsqueeze()はまだ必要に応じて使っている感がありますが、torch.squeeze()は「なんか結果が二重大括弧([[tensor, tensor, …]])で出てきたから次元落とそ」的な使い方が個人的にはほとんどだったりします。恥ずかしながら。
よくあるのが、
4 Class の Classification の Trainng 中、たとえば Batch size を 256 にしたとすると、こんな感じで推論結果が出て来ますよね。
$ print(y_pred.shape, y_pred) torch.Size([256, 4]) tensor([[5.9578e-08, 1.3575e-08, 1.0000e+00, 3.8994e-09], [6.7290e-42, 1.1879e-09, 1.1879e-09, 1.0000e+00], [1.0000e+00, 1.4400e-26, 0.0000e+00, 1.7529e-24], ..., [3.6983e-06, 2.8791e-06, 9.9999e-01, 3.9159e-20], [2.6434e-15, 3.3333e-01, 3.3333e-01, 3.3333e-01], [1.1759e-10, 1.9725e-09, 5.0000e-01, 5.0000e-01]], device='cuda:0', grad_fn=<SoftmaxBackward>)
Classification の結果は以下のようになります。
$ _, pidx = torch.max(y_pred, 1) $ print(pidx.shape, pidx) torch.Size([256]) tensor([2, 3, 0, 0, 3, 0, 0, 0, ..., 2, 2, 2, 2, 2, 1, 2]
これに対して Dataloader から提示されるラベルは、このような状態です。
$ print(y.shape, y) torch.Size([256, 1]) tensor([[0], [3], [0], : [3], [3], [2]], device='cuda:0')
これでは Loss や Accuracy を計算できませんので、次元を合わせます。
$ print(y.squeeze().shape, y.squeeze()) torch.Size([256]) tensor([0, 3, 0, 3, 2, 0, 1, 0, ..., 0, 2, 0, 0, 3, 3, 2], device='cuda:0') $ print(y.view(-1).shape, y.view(-1) torch.Size([256]) tensor([0, 3, 0, 3, 2, 0, 1, 0, ..., 0, 2, 0, 0, 3, 3, 2], device='cuda:0')
この場合は torch.squeeze() でも torch.view(-1) でも同じ結果が得られますので、
$ loss = loss_function(y_pred, y.squeeze()) $ acc = torchmetrics.functional.accuracy(pidx, y.squeeze()) $ loss = loss_function(y_pred, y.view(-1)) $ acc = torchmetrics.functional.accuracy(pidx, y.view(-1))
どちらでも正しく計算できます。
しかしながら、
たとえば Validation 時に Batch size を 1 にして検証すると、
$ print(y_pred.shape, y_pred) torch.Size([1, 4]) tensor([[1.5693e-03, 4.1396e-01, 5.8413e-01, 3.3415e-04]], device='cuda:0', grad_fn=<SoftmaxBackward>) $ _, pidx = torch.max(y_pred, 1) $ print(pidx.shape, pidx) torch.Size([1]) tensor([2], device='cuda:0')
Inference の結果はこのようになりますが、それに対して、
$ print(y.shape, y) torch.Size([1, 1]) tensor([[0]], device='cuda:0') $ print(y.squeeze().shape, y.squeeze()) torch.Size([]) tensor(0, device='cuda:0') $ print(y.view(-1).shape, y.view(-1) torch.Size([1]) tensor([0], device='cuda:0')
torch.squeeze() と torch.view(-1) の結果が異なることがわかります。
もう一度、torch.squeeze() の仕様 を確認してみましょう。Tensor の次元のうち、すべてのサイズ 1 の次元を削除する、とあります。
そうなんです。Training 時は Batch size = 256 でしたので y.shape は [256, 1] でしたが、Validation 時は Batch size = 1 にしたので y.shape は [1, 1] になりますから、y.squeeze() は Scalar の Tensor になってしまったわけです。
ではどうしたら Batch size によらず同じ書き方で行けるでしょう ?
$ loss = loss_function(y_pred, y.squeeze(0)) $ acc = torchmetrics.functional.accuracy(pidx, y.squeeze(0)) $ loss = loss_function(y_pred, y.view(-1)) $ acc = torchmetrics.functional.accuracy(pidx, y.view(-1))
そうです、削除する次元を指定すればよいのです。この場合だと、dim = 0 ですね。
まあそうなのですが、
結局次元を意識するのであれば、torch.view() でいいんじゃないの、というところで。torch.squeeze() で削除する次元はひとつしか指定できないですから、複雑なことをしようとすると torch.squeeze() を複数回呼ばないとできないことになります。一方それは torch.view() であれば1回でできることなわけで…
やはり torch.squeeze() は場当たり的に使いたいと思ってしまいますが、そういう使い方をすると、上の例のように、Training 時には動作していたものが、実際の Inference にコピペして使おうとしてもエラーになる、なんていう落とし穴がありますよ、だったら次元を意識して torch.view() を使った方が安心ですね、というお話しでした。