AI

torch.tensorとtorch.Tensor

新参者は歴史的な経緯を知らないので少しはまりました。Floatゼロ値のスカラーテンソルを作りたかっただけなのですが、torch.Tensor(0.) だとダメで正解は torch.tensor(0.) だったというお話しです。

What is the difference between torch.tensor and torch.Tensor? を読んで違いが何となくわかりましたので、翻訳しておこうかと思います。ただし記事としては3年前の古いものである前提でお読みください。

Question

Version 0.4.0 からtorch.tensorとtorch.Tensorが使えるようになりました。

違いは何でしょう ? どういう背景でこのような見分けがつきにくく混乱しそうな 2 手法が提供されたのでしょうか ?

5 Answers

Answer 1

PyTorch では torch.Tensor がメインのテンソルクラスです。よってすべてのテンソルはtorch.Tensor のインスタンスになります。

torch.Tensor() を呼ぶと、データのない空のテンソルができます。

一方 torch.tensor はテンソルを返す関数です。ドキュメントによると:

torch.tensor (data, dtype=None, device=None, requires_grad=False) → Tensor
データ付きのテンソルを生成します。

また、以下のようにデータなしで torch.Tensor を呼ぶことによって空のテンソルのインスタンスを生成できます。

tensor_without_data = torch.Tensor()

しかし、

tensor_without_data = torch.tensor()

はエラーになります。

—————————————————————————
TypeError Traceback (most recent call last)
<ipython-input-12-ebc3ceaa76d2> in <module>()
—-> 1 torch.tensor()

TypeError: tensor() missing 1 required positional arguments: “data”

まあ torch.tensor の代わりに torch.Tensor を使う理由はないでしょう。torch.Tensor は説明も少ないですしね。

似た方法で空のテンソルを作るとすれば、

torch.tensor(())

によって以下が生成されます。

tensor([])

Answer-2

pytorch discussion  での議論によれば、torch.Tensor コンストラクターは torch.tensor と torch.empty を合わせたようなものです。この類似性は混乱のもとなので、torch.tensor と torch.empty を使いましょう。

つまり、torch.tensor は torch.Tensor(<data>) と同様のふるまいをします。どちらの方が優れているということはありませんが、torch.empty と torch.tensor の方がナイスな API ですね。

Answer-3

Answer-2に追加で気付いたのですが、torch.Tensor(<data>) は torch.get_default_dtype() で定義された default data type でテンソルに初期値が与えられますが、一方で torch.tensor(<data>) は data 自体から data type を推測して設定しますね。

例えば、

torch.tensor ([[2, 5, 6], [9, 7, 6]])

は、tensor([[2, 5, 6], [9, 7, 6]]) を出力し、

torch.Tensor ([[2, 5, 6], [9, 7, 6]])

は、default data type が float32 に設定されていたら、tensor([[2., 5., 6.], [9., 7., 6.]]) を出力します。

Answer-4

Difference between torch.tensor and torch.Tensor

torch.tensor は dtype を自動的に推定しますが、torch.Tensor は torch.FloatTensor を返します。しかし torch.tensor にも dtype のような引数があるので、torch.tensor の使用をお薦めします。

Answer-5

torch.Tensor は、例えば nn.Linear や nn._ConvNd でパラメーターを生成するときに好んで使います。なぜって ? とても速いからです。torch.empty() よりもちょびっと速いです。

import torch
torch.set_default_dtype(torch.float32) # default
%timeit torch.empty(1000,1000)
%timeit torch.Tensor(1000,1000)
%timeit torch.ones(1000,1000)
%timeit torch.tensor([[1]*1000]*1000)

出力:

68.4 µs ± 789 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
67.9 µs ± 349 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
1.26 ms ± 8.61 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
36.1 ms ± 610 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

torch.Tensor() と torch.empty() はほぼ同じで空の(初期値を持たないデータが入った)テンソルを返します。

では __init__ でパラメータの初期化を設定することは技術的には可能でしょうか ?以下に実際の nn.Linear の重みパラメーターの torch.Tensor を示します。

self.weight = nn.Parameter( torch.Tensor (out_features, in_features) )

これを個別に初期化はしません。トレーニング中にパラメーターのリセットを行うかも知れませんから、reset_parameters() メソッドは別にあります。もちろん__init__() メソッドの最後に reset_parameters() を呼ぶことも可能です。

恐らく torch.empty() は torch.Tensor() に取って代わるでしょう。もうひとつオプションがあるとすれば、それは自分の reset_parameters() を書くことです。

さいごに

こちらも古い記事にはなりますが、pytorch 0.4の変更点 を読むと、このときPyTorchはかなり劇的に変わったのですね。ゼロ次元テンソルもこのときから使えるようになったようなので、さまざまな背景があったのだなと理解できます。

グッディーは Keras では Dense() などの引数として用意されている regularizer が PyTorch にはなかったので自分で書くことになったわけですが、そのときにゼロ値のスカラーテンソルが必要になりました。なので、torch.tensor(0.) と書けるのはとても便利 !

ちなみに、左を print すると右になります(初期値は変わります)。

  • torch.tensor(0.) -> tensor(0.)
  • torch.tensor(()) -> tensor([])
  • torch.Tensor(0.) -> TypeError: new() : data must be a sequence (got float)
  • torch.Tensor(0) -> tensor([])
  • torch.Tensor(1) -> tensor([1.8990e+28])
  • torch.Tensor(2) -> tensor([1.0899e-27, 0.])

torch.Tensor でゼロ値のスカラーテンソルを無理やり書くとしたら、*torch.Tensor([0]) でしょうかねえ。

ということで、torch.tensor と torch.Tensor でした。スッキリ !?


   
関連記事
  • Torchvisionのtransforms.Composeを使いこなしてtraining accuracyを上げよう !
  • torch.squeeze()よりtorch.view()が安心だった話し
  • PyTorchで突然malloc(): invalid next size (unsorted)が出たときの対処
  • Ubuntu20.04+CUDA11.2+PyTorch1.8.1+cu111がRTX3090のGPUで動作した !
  • PyTorch LightningのckptファイルをLoadするのにはまった話
  • PyTorch Tutorial その3 – Neural Network

    コメントを残す

    *

    CAPTCHA