機械学習のExampleから覚えるPython(class:__call__)
今までPythonを感覚的に使っていたので、改めて文法を知ろうかなと。
その際にいま流行りの機械学習(深層学習)のExampleを例にすると
わかりやすいのかなと思ったので書いてみる。
※基本的には Python3.x系のつもりで記載してます
などで、class
について記載したのですが、
今回はその中でも独自モデル(Custom Model)を作る場合の記述について見てみます。
参考:https://www.tensorflow.org/guide/keras
Example
class MyModel(tf.keras.Model): def __init__(self, num_classes=10): super(MyModel, self).__init__(name='my_model') self.num_classes = num_classes # Define your layers here. self.dense_1 = layers.Dense(32, activation='relu') self.dense_2 = layers.Dense(num_classes, activation='sigmoid') def call(self, inputs): # Define your forward pass here, # using layers you previously defined (in `__init__`). x = self.dense_1(inputs) return self.dense_2(x) def compute_output_shape(self, input_shape): # You need to override this function if you want to use the subclassed model # as part of a functional-style model. # Otherwise, this method is optional. shape = tf.TensorShape(input_shape).as_list() shape[-1] = self.num_classes return tf.TensorShape(shape)
解説
class MyModel(tf.keras.Model):
tf.keras.Model
を継承
def __init__(self, num_classes=10): super(MyModel, self).__init__(name='my_model') self.num_classes = num_classes # Define your layers here. self.dense_1 = layers.Dense(32, activation='relu') self.dense_2 = layers.Dense(num_classes, activation='sigmoid')
コンストラクタ(初期化:__init__
)にて、レイヤをインスタンス。
def call(self, inputs): # Define your forward pass here, # using layers you previously defined (in `__init__`). x = self.dense_1(inputs) return self.dense_2(x)
call
関数で、forward
の動作を記載します。
さて、なぜcall
関数で?という疑問になりますよね?
tf.keras.Model
をおっていくとbase_layer.pyまでいきつきます。
base_layer
にある __call__
関数で、上記の call
がよばれることになります。
__call__
は組込み関数でクラスをインスタンスして、関数のように呼ばれたときに動きます。
例でいうと、以下のようになります。
model = MyModel
model() # この時に __call__ →...→ call
今回、TensorFlow/Keras を例にしてますが、PyTorchやChainerでも__call__
にて
Model or Layerを構築しています。
forward
やbackward
関数で実装というのもあるので、
各フレームワークにしたがって書きましょう。
def compute_output_shape(self, input_shape): # You need to override this function if you want to use the subclassed model # as part of a functional-style model. # Otherwise, this method is optional. shape = tf.TensorShape(input_shape).as_list() shape[-1] = self.num_classes return tf.TensorShape(shape)
コメントにあるとおり、書いてもいいし。というやつです。
まとめ
今回、call
からの__call__
についてでした。