एक साधारण लॉजिस्टिक प्रतिगमन मॉडल MNIST पर 92% वर्गीकरण सटीकता कैसे प्राप्त करता है?


66

भले ही एमएनआईएसटी डेटासेट में सभी चित्र समान पैमाने के साथ केंद्रित हों, और बिना किसी घुमाव के सामना करते हों, उनके पास एक महत्वपूर्ण लिखावट भिन्नता है जो मुझे पहेली बनाती है कि एक रेखीय मॉडल ऐसी उच्च वर्गीकरण सटीकता कैसे प्राप्त करता है।

जहां तक ​​मैं कल्पना करने में सक्षम हूं, महत्वपूर्ण लिखावट भिन्नता को देखते हुए, अंक 784 आयामी अंतरिक्ष में रैखिक रूप से अविभाज्य होने चाहिए, अर्थात, थोड़ा जटिल (हालांकि बहुत जटिल नहीं) गैर-रेखीय सीमा होनी चाहिए जो अलग-अलग अंकों को अलग करती है। , अच्छी तरह से उद्धृत उदाहरण के समान है जहां सकारात्मक और नकारात्मक कक्षाएं किसी भी रैखिक क्लासिफायरियर द्वारा अलग नहीं की जा सकती हैं। यह मुझे चकित करता है कि बहु-वर्ग लॉजिस्टिक प्रतिगमन पूरी तरह से रैखिक सुविधाओं (कोई बहुपद सुविधाओं) के साथ इतनी उच्च सटीकता कैसे पैदा करता है।XOR

एक उदाहरण के रूप में, किसी भी पिक्सेल को छवि में दिया गया है, अंक और विभिन्न हस्तलिखित रूप उस पिक्सेल को रोशन कर सकते हैं या नहीं। इसलिए, सीखा भार के एक सेट के साथ, प्रत्येक पिक्सेल साथ-साथ भी एक अंक बना सकता है । केवल पिक्सेल मूल्यों के संयोजन के साथ यह कहना संभव है कि क्या अंक या । यह अधिकांश अंकों के जोड़े के लिए सही है। तो, लॉजिस्टिक रिग्रेशन कैसे है, जो नेत्रहीन रूप से सभी पिक्सेल मूल्यों (स्वतंत्र रूप से किसी भी पिक्सेल पर निर्भरता पर विचार किए बिना) पर अपने निर्णय को आधार बनाता है, इस तरह की उच्च सटीकता प्राप्त करने में सक्षम है।232323

मुझे पता है कि मैं कहीं न कहीं गलत हूं या मैं केवल छवियों में भिन्नता का आकलन कर रहा हूं। हालांकि, यह बहुत अच्छा होगा अगर कोई मुझे इस बात पर अंतर्ज्ञान के साथ मदद कर सकता है कि अंक 'लगभग' रैखिक रूप से अलग कैसे हैं।


स्पार्सिटी के साथ पाठ्यपुस्तक सांख्यिकीय लर्निंग पर एक नज़र डालें: लैस्सो और सामान्यीकरण 3.3.1 उदाहरण: हस्तलिखित अंक web.stanford.edu/~hastie/StatLearnSparsity_files/SLS.pdf
एड्रियन

मैं जिज्ञासु रहा हूं: समस्या पर दंडित लीनियर मॉडल (यानी, ग्लमैनेट) जैसा कुछ कितना अच्छा है? अगर मुझे याद है, तो आप जो रिपोर्ट कर रहे हैं, वह अनपेक्षित आकार की सटीकता है।
एबी एबी

जवाबों:


84

tl; dr; भले ही यह एक छवि वर्गीकरण डेटासेट है, यह एक बहुत ही आसान काम है, जिसके लिए इनपुट से भविष्यवाणियों तक कोई भी प्रत्यक्ष मानचित्रण आसानी से पा सकता है ।


उत्तर:

यह एक बहुत ही दिलचस्प सवाल है और लॉजिस्टिक रिग्रेशन की सादगी की बदौलत आप वास्तव में इसका जवाब पा सकते हैं।

लॉजिस्टिक रिग्रेशन क्या है प्रत्येक छवि के लिए इनपुट स्वीकार करते हैं और अपनी भविष्यवाणी बनाने के लिए उन्हें वजन से गुणा करते हैं। दिलचस्प बात यह है कि इनपुट और आउटपुट (यानी कोई छिपी हुई परत) के बीच प्रत्यक्ष मानचित्रण के कारण, प्रत्येक वजन का मूल्य इस बात से मेल खाता है कि प्रत्येक वर्ग की संभावना की गणना करते समय इनपुटों में से प्रत्येक को कितना ध्यान में रखा जाता है। अब, प्रत्येक वर्ग के लिए वेट ले कर उन्हें (यानी इमेज रिज़ॉल्यूशन) में रीशैप करके , हम बता सकते हैं कि प्रत्येक वर्ग की गणना के लिए कौन-सा पिक्सेल सबसे महत्वपूर्ण हैं78478428×28

ध्यान दें, फिर से, कि ये वजन हैं

अब उपरोक्त छवि पर एक नज़र डालें और पहले दो अंकों (यानी शून्य और एक) पर ध्यान केंद्रित करें। ब्लू वेट का मतलब है कि इस पिक्सेल की तीव्रता उस वर्ग के लिए बहुत योगदान करती है और लाल मूल्यों का मतलब है कि यह नकारात्मक रूप से योगदान देता है।

अब सोचिए, एक व्यक्ति को कैसे खींचता है ? वह एक गोलाकार आकृति बनाता है जो बीच में खाली होता है। ठीक यही वज़न उठा। वास्तव में अगर कोई छवि के बीच में खींचता है, तो यह शून्य के रूप में नकारात्मक रूप से गिना जाता है। तो शून्य को पहचानने के लिए आपको कुछ परिष्कृत फिल्टर और उच्च-स्तरीय सुविधाओं की आवश्यकता नहीं है। आप बस इसके अनुसार खींचे गए पिक्सेल स्थानों और न्यायाधीश को देख सकते हैं।0

लिए एक ही बात । छवि के बीच में हमेशा एक सीधी खड़ी रेखा होती है। बाकी सभी नकारात्मक रूप से गिने जाते हैं।1

2378

इसके माध्यम से आप देख सकते हैं कि लॉजिस्टिक रिग्रेशन के पास बहुत सारी छवियों के सही होने की बहुत अच्छी संभावना है और इसीलिए यह इतना उच्च स्कोर करता है।


उपरोक्त आकृति को पुन: पेश करने के लिए कोड थोड़ा दिनांकित है, लेकिन यहां आप जाते हैं:

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST:
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Create model
x = tf.placeholder(tf.float32, shape=(None, 784))
y = tf.placeholder(tf.float32, shape=(None, 10))

W = tf.Variable(tf.zeros((784,10)))
b = tf.Variable(tf.zeros((10)))
z = tf.matmul(x, W) + b

y_hat = tf.nn.softmax(z)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_hat), reduction_indices=[1]))
optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) # 

correct_pred = tf.equal(tf.argmax(y_hat, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Train model
batch_size = 64
with tf.Session() as sess:

    loss_tr, acc_tr, loss_ts, acc_ts = [], [], [], []

    sess.run(tf.global_variables_initializer()) 

    for step in range(1, 1001):

        x_batch, y_batch = mnist.train.next_batch(batch_size) 
        sess.run(optimizer, feed_dict={x: x_batch, y: y_batch})

        l_tr, a_tr = sess.run([cross_entropy, accuracy], feed_dict={x: x_batch, y: y_batch})
        l_ts, a_ts = sess.run([cross_entropy, accuracy], feed_dict={x: mnist.test.images, y: mnist.test.labels})
        loss_tr.append(l_tr)
        acc_tr.append(a_tr)
        loss_ts.append(l_ts)
        acc_ts.append(a_ts)

    weights = sess.run(W)      
    print('Test Accuracy =', sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})) 

# Plotting:
for i in range(10):
    plt.subplot(2, 5, i+1)
    weight = weights[:,i].reshape([28,28])
    plt.title(i)
    plt.imshow(weight, cmap='RdBu')  # as noted by @Eric Duminil, cmap='gray' makes the numbers stand out more
    frame1 = plt.gca()
    frame1.axes.get_xaxis().set_visible(False)
    frame1.axes.get_yaxis().set_visible(False)

11
2378

13
बेशक यह मदद करता है कि एमएनआईएसटी के नमूने केंद्रित, स्केल किए गए, और इसके विपरीत-सामान्यीकृत होते हैं इससे पहले कि क्लासिफायर कभी उन्हें देखता है। आपको "यदि शून्य का किनारा वास्तव में बॉक्स के मध्य से गुजरता है तो" जैसे प्रश्नों का समाधान नहीं करना है? क्योंकि प्री-प्रोसेसर पहले से ही सभी जीरो को समान बनाने की दिशा में एक लंबा रास्ता तय कर चुका है।
हॉब्स

1
@EricDuminil मैंने आपके सुझाव के साथ स्क्रिप्ट पर एक प्रशंसा जोड़ी। इनपुट के लिए बहुत बहुत धन्यवाद! : D
Djib2011

1
@ नीतीश अग्रवाल, यदि आपको लगता है कि यह उत्तर आपके प्रश्न का उत्तर है, तो इसे चिह्नित करने पर विचार करें।
सिंटेक्स

9
किसी ऐसे व्यक्ति के लिए जो इस तरह की प्रसंस्करण से विशेष रूप से परिचित नहीं है, लेकिन यह उत्तर यांत्रिकी का एक शानदार सहज ज्ञान युक्त उदाहरण प्रदान करता है।
क्राइसिस -ऑन स्ट्राइक-
हमारी साइट का प्रयोग करके, आप स्वीकार करते हैं कि आपने हमारी Cookie Policy और निजता नीति को पढ़ और समझा लिया है।
Licensed under cc by-sa 3.0 with attribution required.