学習済みモデルの中間層出力である特徴量を取得して、別の学習器に適用したいことってありますよね ?
Kerasでは出力を再定義([probability] -> [probability, features])して model.compile し直すのが普通だと思うのですが、PyTorchではどうなのかと調べてみました。PyTorchでは forward/backward hook でフックするしくみが用意されているのですね。
モデルの定義をいじる必要がなく便利だったので纏めておきます。この記事は 【pytorch】複雑に入り組んでるモデルの中間層の出力を取得したい をベースに実装したものです。
Forward Hook の実装
これはもう決まり事として以下のクラスを用意します。要は盗み見したい中間層出力を常に self.outputs に蓄積し、好みに応じてクリアできるメソッドを提供します。
class SaveOutput: def init(self): self.outputs = []def __call__(self, module, module_in, module_out):
self.outputs.append(module_out.detach())
def clear(self):
self.outputs = []
Forward Hook の登録
たとえばモデル (self.model) を定義するクラスの __init__() で以下のように hook handle を登録します。ここではモデル後段にあるひとつめの全結合層の出力 (self.model.fc[0]) を Hook する想定です。レイヤー名は model.state_dict() などで予め調べておく必要はあります。
self.save_output = SaveOutput()
hook_handles = []
layer = self.model.fc[0]
# whatever you want to snoophandle = layer.register_forward_hook(self.save_output)
hook_handles.append(handle)
結果の取得
たとえばモデルを定義したクラスの prediction メソッドで取得する場合、以下のような実装になります。
class SaveOutput: def init(self): self.outputs = []def __call__(self, module, module_in, module_out):
self.outputs.append(module_out.detach())
def clear(self):
self.outputs = []
class My_Model:def __init__(self):
self.model = <model>.to('cuda:0')
self.transform = transforms.Compose(<transformation definition>)
self.save_output = SaveOutput()
hook_handles = []
layer = self.model.fc[0]
handle = layer.register_forward_hook(self.save_output)
hook_handles.append(handle)
def predict(self, image):
transformed_image = self.transform(image)
result = torch.sigmoid(self.model(transformed_image)).to('cpu').detach().numpy().copy()
features = self.save_output.outputs[0].to('cpu').detach().numpy().copy()
self.save_output.clear()
# clear everytimereturn result, features
if __name__=='__main__': mm = My_Model() while True: image = <get image> result, features = mm.predict(image)
さいごに
スッキリ書けて、結構便利だなと思いました。
あたりまえのことですが、self.model(transformed_image) のあとに self.save_output.outputs を見に行く必要があります。でないと、一番最初のループで実体がないものにアクセスすることになり、エラーになります。グッディーは最初これを逆に書いたために半日はまってしまいました…