본문 바로가기
Dev/딥러닝

07-3. Tensor flow 에서 epoch 과 batch 설명 및 예제

by bsion 2018. 8. 23.
07-3. Training using Epoch and Batch

출처 : 모두를위한 머신러닝 (http://hunkim.github.io/ml/)


Tensor Flow 에서 많은 데이터를 학습시킬 때 사용하는 Epoch 과 Batch 에 대해 알아본다


큰 사이즈의 데이터를 읽고 학습시키려면 그만큼의 벡터공간도 많이 필요하므로 일정크기 만큼 나눠서 학습을 시키는것이 효과적이다. 예를들어 100만개의 데이터가 있을 경우, 10만개씩 학습을 시키고 모델은 학습된 결과를 저장해둔 상태로 다음 10만개의 데이터를 학습하는 과정을 반복한다. 이 방법이 효과적인 이유는, 100만개의 데이터를 학습시켜두었는데 새로운 10만개의 데이터가 생겼을 때 110만개를 학습시키는것이 아닌 생성해둔 모델을 불러와서 사용 할 수 있기 때문이다. Online learning 이라는 방법으로 처리한다.

  • Epoch : 전체 데이터를 학습시킨 횟수를 의미한다.
  • Batch size : 전체 데이터에서 일부만 잘라서 학습시킬경우 그 일부분의 갯수를 의미한다.

Example) 1000 개의 Training set 이 있고 batch_size=100 이라면, 1 epoch 에 필요한 iteration 은 10 이 된다.


Epoch과 Batch 를 사용한 예시

수기로 작성한 숫자를 학습시켜서 읽어들이는 유명한 예제로 연습해본다. 데이터는 Tensorflow 내장함수의 tutorial 에서 읽어온다.

In [2]:
import warnings
warnings.filterwarnings("ignore")

import tensorflow as tf
import random
import matplotlib.pyplot as plt
tf.set_random_seed(777)  # for reproducibility

from tensorflow.examples.tutorials.mnist import input_data
# Check out https://www.tensorflow.org/get_started/mnist/beginners for
# more information about the mnist dataset
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


0부터 9까지 숫자를 읽는 모델이므로 총 10개의 Y 값이 가능하다

In [3]:
nb_classes = 10


수기로 작성된 숫자의 이미지 파일의 픽셀크기가 가로28, 세로28 이므로 28 * 28 = 784 의 길이를 갖는다.

In [4]:
# MNIST data image of shape 28 * 28 = 784
X = tf.placeholder(tf.float32, [None, 784])
# 0 - 9 digits recognition = 10 classes
Y = tf.placeholder(tf.float32, [None, nb_classes])

W = tf.Variable(tf.random_normal([784, nb_classes]))
b = tf.Variable(tf.random_normal([nb_classes]))


Hypothesis 와 Cost function 은 이전에 사용한 예제와 동일하다

In [5]:
# Hypothesis (using softmax)
hypothesis = tf.nn.softmax(tf.matmul(X, W) + b)

cost = tf.reduce_mean(-tf.reduce_sum(Y * tf.log(hypothesis), axis=1))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(cost)

# Test model
is_correct = tf.equal(tf.arg_max(hypothesis, 1), tf.arg_max(Y, 1))
# Calculate accuracy
accuracy = tf.reduce_mean(tf.cast(is_correct, tf.float32))
WARNING:tensorflow:From <ipython-input-5-c93befc7c1ad>:8: arg_max (from tensorflow.python.ops.gen_math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `argmax` instead


batch size 는 100으로 설정하고, epoch 은 15 로 설정했다. 100개 간격으로 데이터를 읽고, 전체 데이터를 15회 학습한다는 의미가 되겠다.

In [6]:
# parameters
training_epochs = 15
batch_size = 100


학습 및 테스트

In [11]:
with tf.Session() as sess:
    # Initialize TensorFlow variables
    sess.run(tf.global_variables_initializer())
    # Training cycle
    for epoch in range(training_epochs):
        avg_cost = 0
        total_batch = int(mnist.train.num_examples / batch_size)

        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            c, _ = sess.run([cost, optimizer], feed_dict={
                            X: batch_xs, Y: batch_ys})
            avg_cost += c / total_batch

        print('Epoch:', '%04d' % (epoch + 1),
              'cost =', '{:.9f}'.format(avg_cost))

    print("Learning finished")

    # Test the model using test sets
    print("\n---------------------------")
    print("Accuracy: ", accuracy.eval(session=sess, feed_dict={
          X: mnist.test.images, Y: mnist.test.labels}))
    
    # Get one and predict
    r = random.randint(0, mnist.test.num_examples - 1)
    print("\n---------------------------")
    print("Label: ", sess.run(tf.argmax(mnist.test.labels[r:r + 1], 1)))
    print("Prediction: ", sess.run(
        tf.argmax(hypothesis, 1), feed_dict={X: mnist.test.images[r:r + 1]}))

    plt.imshow(
        mnist.test.images[r:r + 1].reshape(28, 28),
        cmap='Greys',
        interpolation='nearest')
    plt.show()
Epoch: 0001 cost = 2.833301682
Epoch: 0002 cost = 1.061346977
Epoch: 0003 cost = 0.837722787
Epoch: 0004 cost = 0.733728423
Epoch: 0005 cost = 0.669795933
Epoch: 0006 cost = 0.624772113
Epoch: 0007 cost = 0.591003590
Epoch: 0008 cost = 0.563903705
Epoch: 0009 cost = 0.541222024
Epoch: 0010 cost = 0.522424459
Epoch: 0011 cost = 0.506267789
Epoch: 0012 cost = 0.492151302
Epoch: 0013 cost = 0.479882531
Epoch: 0014 cost = 0.468705962
Epoch: 0015 cost = 0.458789315
Learning finished

---------------------------
Accuracy:  0.8948

---------------------------
Label:  [0]
Prediction:  [0]


댓글