본문 바로가기
Dev/딥러닝

07-1. Tensor Flow 에서 Learning rate 이란

by bsion 2018. 8. 22.
07-1. Learning rate

출처 : 모두를위한 머신러닝 (http://hunkim.github.io/ml/)

Tensor Flow 에서 사용하는 Learning rate 에 관하여 이론과 예제를 통해 설명한다.


Learning Rate in Gradient Descent Optimizer

Gradient descent optimizer 를 사용할 때 learning_rate 을 사용하였다. learning_rate 을 직관적으로 표현하면 다음 그림처럼 나타낼 수 있다.


  • Learning rate 이 너무 클 경우, 정확한 minimum point 에 도달하지 못하거나 데이터가 튀어서 무한대로 나갈수가 있다.
  • Learning rate 이 너무 작을경우, 변화가 너무 적어서 최대 step 에 도달해도 minimum point를 찾지 못할 수가 있다.
  • 따라서 Cost function 에 따라 적당한 learning rate 이 다르므로, 테스트를 통해 적당한 값을 찾아내는것이 중요하다.

예제

Learning rate 을 테스트하기위해 기본적인 작업을 셋팅함

In [8]:
import tensorflow as tf
tf.set_random_seed(777)  # for reproducibility

x_data = [[1, 2, 1],
          [1, 3, 2],
          [1, 3, 4],
          [1, 5, 5],
          [1, 7, 5],
          [1, 2, 5],
          [1, 6, 6],
          [1, 7, 7]]
y_data = [[0, 0, 1],
          [0, 0, 1],
          [0, 0, 1],
          [0, 1, 0],
          [0, 1, 0],
          [0, 1, 0],
          [1, 0, 0],
          [1, 0, 0]]


# Evaluation our model using this test dataset
x_test = [[2, 1, 1],
          [3, 1, 2],
          [3, 3, 4]]
y_test = [[0, 0, 1],
          [0, 0, 1],
          [0, 0, 1]]

X = tf.placeholder("float", [None, 3])
Y = tf.placeholder("float", [None, 3])

W = tf.Variable(tf.random_normal([3, 3]))
b = tf.Variable(tf.random_normal([3]))

# tf.nn.softmax computes softmax activations
# softmax = exp(logits) / reduce_sum(exp(logits), dim)
hypothesis = tf.nn.softmax(tf.matmul(X, W) + b)

# Cross entropy cost/loss
cost = tf.reduce_mean(-tf.reduce_sum(Y * tf.log(hypothesis), axis=1))

매우 큰 Learning rate

5번재 step 에서 무한대가 나오고, 6 번째 step 에서부터 nan 이 출력되었다. 따라서 Cost function 이 nan 이 출력되는 경우 learning rate 이 너무 커서 그럴 수 있다.

In [7]:
optimizer = tf.train.GradientDescentOptimizer(
    learning_rate=1.5).minimize(cost)

# Correct prediction Test model
prediction = tf.arg_max(hypothesis, 1)
is_correct = tf.equal(prediction, tf.arg_max(Y, 1))
accuracy = tf.reduce_mean(tf.cast(is_correct, tf.float32))

# Launch graph
with tf.Session() as sess:
    # Initialize TensorFlow variables
    sess.run(tf.global_variables_initializer())

    for step in range(201):
        cost_val, W_val, _ = sess.run(
            [cost, W, optimizer], feed_dict={X: x_data, Y: y_data})
        if step < 10 or step > 195:
            print(step, cost_val, W_val)

    # predict
    print("Prediction:", sess.run(prediction, feed_dict={X: x_test}))
    # Calculate the accuracy
    print("Accuracy: ", sess.run(accuracy, feed_dict={X: x_test, Y: y_test}))
0 5.7320304 [[ 0.8026957  0.6786129 -1.2172831]
 [-0.3051686 -0.3032113  1.508257 ]
 [ 0.7572236 -0.7008909 -2.108204 ]]
1 23.149357 [[-0.30548945  1.2298502  -0.6603353 ]
 [-4.3907      2.2967086   2.9938684 ]
 [-3.345107    2.0974321  -0.80419564]]
2 27.27978 [[ 0.06951055  0.29449674 -0.09998183]
 [-1.9531999  -1.6362796   4.4893565 ]
 [-0.9076071  -1.6502014   0.50593793]]
3 8.668001 [[ 0.44451022  0.85699666 -1.0374814 ]
 [ 0.48429942  0.98872024 -0.57314277]
 [ 1.5298926   1.1622986  -4.7440615 ]]
4 5.7710896 [[ 0.12396157  0.6150459  -0.47498202]
 [ 0.22003089 -0.24701011  0.92685604]
 [ 0.9603522   0.41933942 -3.431562  ]]
5 inf [[-0.95243084  1.1303774   0.08607888]
 [-3.786516    2.2624538   2.4239388 ]
 [-3.0717096   3.1403794  -2.1205401 ]]
6 nan [[nan nan nan]
 [nan nan nan]
 [nan nan nan]]
7 nan [[nan nan nan]
 [nan nan nan]
 [nan nan nan]]
8 nan [[nan nan nan]
 [nan nan nan]
 [nan nan nan]]
9 nan [[nan nan nan]
 [nan nan nan]
 [nan nan nan]]
196 nan [[nan nan nan]
 [nan nan nan]
 [nan nan nan]]
197 nan [[nan nan nan]
 [nan nan nan]
 [nan nan nan]]
198 nan [[nan nan nan]
 [nan nan nan]
 [nan nan nan]]
199 nan [[nan nan nan]
 [nan nan nan]
 [nan nan nan]]
200 nan [[nan nan nan]
 [nan nan nan]
 [nan nan nan]]
Prediction: [0 0 0]
Accuracy:  0.0

매우 작은 Learning rate

Cost function의 변화가 너무 미미하다.

In [9]:
optimizer = tf.train.GradientDescentOptimizer(
    learning_rate=1e-10).minimize(cost)

# Correct prediction Test model
prediction = tf.arg_max(hypothesis, 1)
is_correct = tf.equal(prediction, tf.arg_max(Y, 1))
accuracy = tf.reduce_mean(tf.cast(is_correct, tf.float32))

# Launch graph
with tf.Session() as sess:
    # Initialize TensorFlow variables
    sess.run(tf.global_variables_initializer())

    for step in range(201):
        cost_val, W_val, _ = sess.run(
            [cost, W, optimizer], feed_dict={X: x_data, Y: y_data})
        if step < 10 or step > 195:
            print(step, cost_val, W_val)

    # predict
    print("Prediction:", sess.run(prediction, feed_dict={X: x_test}))
    # Calculate the accuracy
    print("Accuracy: ", sess.run(accuracy, feed_dict={X: x_test, Y: y_test}))
0 8.805746 [[ 0.5881828   0.37383866 -0.00470501]
 [-0.34741014 -1.4699485   0.5476274 ]
 [ 0.15381092  0.40726185  1.2065208 ]]
1 8.805746 [[ 0.5881828   0.37383866 -0.00470501]
 [-0.34741014 -1.4699485   0.5476274 ]
 [ 0.15381092  0.40726185  1.2065208 ]]
2 8.805746 [[ 0.5881828   0.37383866 -0.00470501]
 [-0.34741014 -1.4699485   0.5476274 ]
 [ 0.15381092  0.40726185  1.2065208 ]]
3 8.805746 [[ 0.5881828   0.37383866 -0.00470501]
 [-0.34741014 -1.4699485   0.5476274 ]
 [ 0.15381092  0.40726185  1.2065208 ]]
4 8.805746 [[ 0.5881828   0.37383866 -0.00470501]
 [-0.34741014 -1.4699485   0.5476274 ]
 [ 0.15381092  0.40726185  1.2065208 ]]
5 8.805746 [[ 0.5881828   0.37383866 -0.00470501]
 [-0.34741014 -1.4699485   0.5476274 ]
 [ 0.15381092  0.40726185  1.2065208 ]]
6 8.805746 [[ 0.5881828   0.37383866 -0.00470501]
 [-0.34741014 -1.4699485   0.5476274 ]
 [ 0.15381092  0.40726185  1.2065208 ]]
7 8.805746 [[ 0.5881828   0.37383866 -0.00470501]
 [-0.34741014 -1.4699485   0.5476274 ]
 [ 0.15381092  0.40726185  1.2065208 ]]
8 8.805746 [[ 0.5881828   0.37383866 -0.00470501]
 [-0.34741014 -1.4699485   0.5476274 ]
 [ 0.15381092  0.40726185  1.2065208 ]]
9 8.805746 [[ 0.5881828   0.37383866 -0.00470501]
 [-0.34741014 -1.4699485   0.5476274 ]
 [ 0.15381092  0.40726185  1.2065208 ]]
196 8.805746 [[ 0.5881828   0.37383866 -0.00470501]
 [-0.34741014 -1.4699485   0.5476274 ]
 [ 0.15381092  0.40726185  1.2065208 ]]
197 8.805746 [[ 0.5881828   0.37383866 -0.00470501]
 [-0.34741014 -1.4699485   0.5476274 ]
 [ 0.15381092  0.40726185  1.2065208 ]]
198 8.805746 [[ 0.5881828   0.37383866 -0.00470501]
 [-0.34741014 -1.4699485   0.5476274 ]
 [ 0.15381092  0.40726185  1.2065208 ]]
199 8.805746 [[ 0.5881828   0.37383866 -0.00470501]
 [-0.34741014 -1.4699485   0.5476274 ]
 [ 0.15381092  0.40726185  1.2065208 ]]
200 8.805746 [[ 0.5881828   0.37383866 -0.00470501]
 [-0.34741014 -1.4699485   0.5476274 ]
 [ 0.15381092  0.40726185  1.2065208 ]]
Prediction: [2 2 2]
Accuracy:  1.0


댓글