
06-1. Classification with Three or More Classes in TensorFlow (Softmax Classifier)

by bsion 2018. 8. 20.
06-1. Softmax Classification

Source: Machine Learning for Everyone (모두를 위한 머신러닝), http://hunkim.github.io/ml/


Theory




Multinomial Classification

Classification into three or more classes.


Softmax function
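
As a rough illustration of what the softmax function does, here is a minimal pure-Python sketch. It assumes the standard definition, exp(score) divided by the sum of exp over all scores. The helper name `softmax` and the example scores are chosen only for illustration and do not come from the post's code; subtracting the maximum score is a common stability trick that leaves the result unchanged.

```python
import math

def softmax(logits):
    """Map a list of raw scores (logits) to probabilities that sum to 1."""
    # Subtracting the max does not change the result but avoids overflow in exp().
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Illustrative scores for three classes (assumed values, not from the post).
probs = softmax([2.0, 1.0, 0.1])
print(probs)       # three probabilities, largest for the largest score
print(sum(probs))  # the probabilities add up to 1 (up to rounding)
```

The largest score gets the largest probability, and the outputs can be read as the model's confidence in each class.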



In [2]:
import tensorflow as tf

x_data = [[1, 2, 1, 1], [2, 1, 3, 2], [3, 1, 3, 4], [4, 1, 5, 5], [1, 7, 5, 5],
         [1, 2, 5, 6], [1, 6, 6, 6], [1, 7, 7, 7]]

# One-hot encoding: exactly one element is 1, the rest are 0
#             C          C          C          B          B          B          A          A
y_data = [[0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 1, 0], [0, 1, 0], [0, 1, 0], [1, 0, 0], [1, 0, 0]]

X = tf.placeholder("float", [None, 4])
Y = tf.placeholder("float", [None, 3])
nb_classes = 3

W = tf.Variable(tf.random_normal([4, nb_classes]), name='weight')
b = tf.Variable(tf.random_normal([nb_classes]), name='bias')

# tf.nn.softmax = exp(Logits) / reduce_sum(exp(Logits), dim)
hypothesis = tf.nn.softmax(tf.matmul(X,W) + b)

# Cross entropy cost(loss)
cost = tf.reduce_mean(-tf.reduce_sum(Y * tf.log(hypothesis), axis=1))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(cost)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    
    for step in range(4001):
        sess.run(optimizer, feed_dict={X: x_data, Y: y_data})
        if step % 400 == 0:
            print(step, sess.run(cost, feed_dict={X: x_data, Y: y_data}))
            
    # TEST
    print("---------------------\nTEST")
    a = sess.run(hypothesis, feed_dict={X: [[1, 11, 7, 9]]})
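    print(a)                             # predicted probabilities for the 3 classes
    # arg_max returns the index of the largest value along axis 1. It is deprecated,
    # which is what triggers the warning in the output; tf.argmax is the replacement.
    print(sess.run(tf.arg_max(a, 1)))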
0 9.722477
400 0.34384868
800 0.23956646
1200 0.19624427
1600 0.16640283
2000 0.14438885
2400 0.12744662
2800 0.11400323
3200 0.10308075
3600 0.09403515
4000 0.08642442
---------------------
TEST
[[1.8871509e-04 9.9981064e-01 6.7583107e-07]]
WARNING:tensorflow:From <ipython-input-2-42b47e12e19a>:36: arg_max (from tensorflow.python.ops.gen_math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `argmax` instead
[1]
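
To make the cost expression and the final `[1]` more concrete, the following pure-Python sketch recomputes both by hand from the TEST probabilities printed above. The helper names `cross_entropy` and `argmax` are hypothetical, and the label `[0, 1, 0]` is an assumption made only for illustration, since the post does not give a true label for the test input.

```python
import math

def cross_entropy(y_onehot, probs):
    """Cross entropy for one sample: -sum(y_i * log(p_i))."""
    # Terms with y_i == 0 contribute nothing, so they are skipped.
    return -sum(y * math.log(p) for y, p in zip(y_onehot, probs) if y > 0)

def argmax(values):
    """Index of the largest value, like argmax over one row."""
    return max(range(len(values)), key=lambda i: values[i])

# Probabilities copied from the TEST output above.
test_probs = [1.8871509e-04, 9.9981064e-01, 6.7583107e-07]

predicted = argmax(test_probs)
print(predicted)  # 1

# Assumed label for illustration only: index 1, i.e. [0, 1, 0].
loss = cross_entropy([0, 1, 0], test_probs)
print(loss)  # small, because the model puts almost all probability on index 1
```

With the one-hot layout used in `y_data`, index 1 corresponds to the rows labeled B. The `cost` in the TensorFlow code is the mean of this per-sample cross entropy over all training rows.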

