機械学習のExampleから覚えるPython(class:__call__)

今までPythonを感覚的に使っていたので、改めて文法を知ろうかなと。
その際にいま流行りの機械学習(深層学習)のExampleを例にすると
わかりやすいのかなと思ったので書いてみる。

※基本的には Python3.x系のつもりで記載してます

などで、class について記載したのですが、
今回はその中でも独自モデル(Custom Model)を作る場合の記述について見てみます。

参考:https://www.tensorflow.org/guide/keras

Example

class MyModel(tf.keras.Model):

  def __init__(self, num_classes=10):
    super(MyModel, self).__init__(name='my_model')
    self.num_classes = num_classes
    # Define your layers here.
    self.dense_1 = layers.Dense(32, activation='relu')
    self.dense_2 = layers.Dense(num_classes, activation='sigmoid')

  def call(self, inputs):
    # Define your forward pass here,
    # using layers you previously defined (in `__init__`).
    x = self.dense_1(inputs)
    return self.dense_2(x)

  def compute_output_shape(self, input_shape):
    # You need to override this function if you want to use the subclassed model
    # as part of a functional-style model.
    # Otherwise, this method is optional.
    shape = tf.TensorShape(input_shape).as_list()
    shape[-1] = self.num_classes
    return tf.TensorShape(shape)

解説

class MyModel(tf.keras.Model):

tf.keras.Modelを継承

  def __init__(self, num_classes=10):
    super(MyModel, self).__init__(name='my_model')
    self.num_classes = num_classes
    # Define your layers here.
    self.dense_1 = layers.Dense(32, activation='relu')
    self.dense_2 = layers.Dense(num_classes, activation='sigmoid')

コンストラクタ(初期化:__init__)にて、レイヤをインスタンス

  def call(self, inputs):
    # Define your forward pass here,
    # using layers you previously defined (in `__init__`).
    x = self.dense_1(inputs)
    return self.dense_2(x)

call関数で、forwardの動作を記載します。
さて、なぜcall関数で?という疑問になりますよね?

tf.keras.Modelをおっていくとbase_layer.pyまでいきつきます。

base_layerにある __call__関数で、上記の callがよばれることになります。

__call__ は組込み関数でクラスをインスタンスして、関数のように呼ばれたときに動きます。

例でいうと、以下のようになります。

model = MyModel

model() # この時に __call__ →...→ call

今回、TensorFlow/Keras を例にしてますが、PyTorchやChainerでも__call__にて
Model or Layerを構築しています。

forwardbackward関数で実装というのもあるので、
フレームワークにしたがって書きましょう。

  def compute_output_shape(self, input_shape):
    # You need to override this function if you want to use the subclassed model
    # as part of a functional-style model.
    # Otherwise, this method is optional.
    shape = tf.TensorShape(input_shape).as_list()
    shape[-1] = self.num_classes
    return tf.TensorShape(shape)

コメントにあるとおり、書いてもいいし。というやつです。

まとめ

今回、callからの__call__についてでした。