AI

PyTorchで学習済みモデルの中間層出力の取得

学習済みモデルの中間層出力である特徴量を取得して、別の学習器に適用したいことってありますよね ?

Kerasでは出力を再定義([probability] -> [probability, features])して model.compile し直すのが普通だと思うのですが、PyTorchではどうなのかと調べてみました。PyTorchでは forward/backward hook でフックするしくみが用意されているのですね。

モデルの定義をいじる必要がなく便利だったので纏めておきます。この記事は 【pytorch】複雑に入り組んでるモデルの中間層の出力を取得したい をベースに実装したものです。

Forward Hook の実装

これはもう決まり事として以下のクラスを用意します。要は盗み見したい中間層出力を常に self.outputs に蓄積し、好みに応じてクリアできるメソッドを提供します。

class SaveOutput:
  def init(self):
     self.outputs = []
  def __call__(self, module, module_in, module_out): 
    self.outputs.append(module_out.detach()) 
  def clear(self): 
    self.outputs = []

Forward Hook の登録

たとえばモデル (self.model) を定義するクラスの __init__() で以下のように hook handle を登録します。ここではモデル後段にあるひとつめの全結合層の出力 (self.model.fc[0]) を Hook する想定です。レイヤー名は model.state_dict() などで予め調べておく必要はあります。

  self.save_output = SaveOutput() 
  hook_handles = [] 
  layer = self.model.fc[0]     # whatever you want to snoop
  handle = layer.register_forward_hook(self.save_output) 
  hook_handles.append(handle)

結果の取得

たとえばモデルを定義したクラスの prediction メソッドで取得する場合、以下のような実装になります。

class SaveOutput:
  def init(self):
     self.outputs = []
  def __call__(self, module, module_in, module_out): 
    self.outputs.append(module_out.detach()) 
  def clear(self): 
    self.outputs = []

class My_Model:
  def __init__(self): 
    self.model = <model>.to('cuda:0') 
    self.transform = transforms.Compose(<transformation definition>) 

    self.save_output = SaveOutput() 
    hook_handles = [] 
    layer = self.model.fc[0] 
    handle = layer.register_forward_hook(self.save_output) 
    hook_handles.append(handle) 

  def predict(self, image): 
    transformed_image = self.transform(image) 
    result = torch.sigmoid(self.model(transformed_image)).to('cpu').detach().numpy().copy() 

    features = self.save_output.outputs[0].to('cpu').detach().numpy().copy() 
    self.save_output.clear()    # clear everytime

    return result, features

if __name__=='__main__':
  mm = My_Model()
  while True:
    image = <get image>
    result, features = mm.predict(image)

さいごに

スッキリ書けて、結構便利だなと思いました。

あたりまえのことですが、self.model(transformed_image) のあとに self.save_output.outputs を見に行く必要があります。でないと、一番最初のループで実体がないものにアクセスすることになり、エラーになります。グッディーは最初これを逆に書いたために半日はまってしまいました…


   
関連記事
  • PyTorchでclass_weightを適用するには
  • PyTorch Tutorial その3 – Neural Network
  • PyTorch Tutorial その4 – Training Classifier
  • Ubuntu20.04+CUDA11.2+PyTorch1.8.1+cu111がRTX3090のGPUで動作した !
  • Pytorch Tutorial その1 – Tensor
  • PyTorch Tutorial その2 – torch.autograd

    コメントを残す

    *

    CAPTCHA