機械学習のExampleから覚えるPython(関数の引数)

今までPythonを感覚的に使っていたので、改めて文法を知ろうかなと。
その際にいま流行りの機械学習(深層学習)のExampleを例にすると
わかりやすいのかなと思ったので書いてみる。

※基本的には Python3.x系のつもりで記載してます

Example

https://www.tensorflow.org/tutorials/ に記載されている

import tensorflow as tf
mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(512, activation=tf.nn.relu),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(x_train, y_train, epochs=5)
model.evaluate(x_test, y_test)

※本記事記載の時点のコードです。

関数の引数

前回 機械学習のExampleから覚えるPython(関数呼び出し)にて、load_data 関数の引数 pathが呼び出し時に省略されていることが確認できました。

関数の引数については3つの指定ができます。

  • dafaultによる省略
  • 引数を入力(引数名無し)
  • 引数名を指定

defaultによる省略

関数定義側で、default値が設定している場合に
呼び出し側が default値で良い場合省略することができます。

  • 関数定義側
def load_data(path='mnist.npz'):
  • 呼び出し側
(x_train, y_train),(x_test, y_test) = mnist.load_data()

pathmnist.npzの値で処理されます。

引数を入力(引数名無し)

model.fit(x_train, y_train, epochs=5)

このコードを見てみます。
定義側はこのようになってます。

  def fit(self,
          x=None,
          y=None,
          batch_size=None,
          epochs=1,
          verbose=1,
          callbacks=None,
          validation_split=0.,
          validation_data=None,
          shuffle=True,
          class_weight=None,
          sample_weight=None,
          initial_epoch=0,
          steps_per_epoch=None,
          validation_steps=None,
          validation_freq=1,
          max_queue_size=10,
          workers=1,
          use_multiprocessing=False,
          **kwargs):

self は一旦無視していただいて、x_train, y_trainfit関数の
どの引数にあたるかというと、順番に設定しています。
つまり、以下のようになります。

#         x=x_train, y=y_train
model.fit(x_train, y_train, epochs=5)

引数名を指定

先ほどのコードの epochsの設定が該当します。

model.fit(x_train, y_train, epochs=5)

fit関数で定義されている epochsを指定して 5という値を入力しています。

可変長引数

関数の定義では、可変長引数 を定義することができます。
fit関数でも最後の引数に出てきていた **kwargsが該当します。

こちらのページが参考になると思いますので、リンク貼っときます。