एक बैच में प्रत्येक उदाहरण के लिए मिनीबैड ग्रेडिएंट डीसेंट वेट को कैसे अपडेट करता है?


12

यदि हम एक बैच में 10 उदाहरण कहते हैं, तो मैं समझता हूं कि हम प्रत्येक उदाहरण के लिए नुकसान का योग कर सकते हैं, लेकिन प्रत्येक उदाहरण के लिए भार को अद्यतन करने के संबंध में बैकप्रॉपैगैशन कैसे काम करता है?

उदाहरण के लिए:

  • उदाहरण 1 -> नुकसान = 2
  • उदाहरण 2 -> नुकसान = -2

इससे 0 (E = 0) का औसत नुकसान होता है, इसलिए यह प्रत्येक वजन को कैसे अपडेट करेगा और कैसे करेगा? क्या यह केवल उन बैचों के यादृच्छिककरण से है जो हम "उम्मीद" जल्दी या बाद में अभिसरण करते हैं? इसके अलावा यह केवल संसाधित किए गए अंतिम उदाहरण के लिए वजन के पहले सेट के लिए ढाल की गणना नहीं करता है?

जवाबों:


15

धीरे-धीरे वंश आपके द्वारा सुझाए गए तरीके से बहुत काम नहीं करता है लेकिन एक समान समस्या हो सकती है।

हम बैच से औसत नुकसान की गणना नहीं करते हैं, हम नुकसान फ़ंक्शन के औसत ग्रेडिएंट्स की गणना करते हैं। ग्रेडिएंट्स वजन के संबंध में नुकसान के व्युत्पन्न हैं और तंत्रिका नेटवर्क में एक वजन के लिए ढाल उस विशिष्ट उदाहरण के इनपुट पर निर्भर करता है और यह मॉडल में कई अन्य भारों पर भी निर्भर करता है।

यदि आपके मॉडल में 5 वज़न है और आपके पास 2 के आकार का मिनी-बैच है तो आपको यह मिल सकता है:

उदाहरण 1. हानि = 2,ढ़ाल=(1.5,-2.0,1.1,0.4,-0.9)

उदाहरण 2. हानि = 3,ढ़ाल=(1.2,2.3,-1.1,-0.8,-0.7)

इस मिनी-बैच में ग्रेडिएंट्स की औसत गणना की जाती है, वे हैं(1.35,0.15,0,-0.2,-0.8)

कई उदाहरणों पर औसत का लाभ यह है कि ग्रेडिएंट में भिन्नता कम है, इसलिए शिक्षण अधिक सुसंगत है और एक उदाहरण की बारीकियों पर कम निर्भर है। ध्यान दें कि तीसरे वजन के लिए औसत ढाल , यह भार इस वजन अद्यतन को नहीं बदलेगा, लेकिन यह संभवतः अगले उदाहरणों के लिए गैर-शून्य होगा जो विभिन्न भारों के साथ गणना करते हैं।0

टिप्पणियों के जवाब में संपादित करें:

ग्रेडियरों के औसत से ऊपर मेरे उदाहरण में गणना की गई है। के एक मिनी-बैच आकार के लिए जहां हम प्रत्येक उदाहरण के लिए नुकसान गणना करते हैं, जिसका उद्देश्य वजन संबंध में नुकसान के औसत ग्रेडिएंट को प्राप्त करना है ।L i w jएलमैंwजे

जिस तरह से मैंने इसे अपने उदाहरण में लिखा है मैंने प्रत्येक ग्रेडिएंट को औसतन लिखा है:एलwजे=1Σमैं=1एलमैंwजे

टिप्पणियों में आपके द्वारा लिंक किया गया ट्यूटोरियल कोड औसत हानि को कम करने के लिए टेन्सरफ्लो का उपयोग करता है।

Tensorflow का उद्देश्य1Σमैं=1एलमैं

इसे कम करने के लिए यह प्रत्येक वजन के संबंध में औसत हानि के ग्रेडिएंट्स की गणना करता है और वजन को अद्यतन करने के लिए ढाल-वंश का उपयोग करता है:

एलwजे=wजे1Σमैं=1एलमैं

अंतर को योग के अंदर लाया जा सकता है, इसलिए यह मेरे उदाहरण में दृष्टिकोण से अभिव्यक्ति के समान है।

wजे1Σमैं=1एलमैं=1Σमैं=1एलमैंwजे


पकड़ लिया। आप अभी भी बैच_साइज़ पर हुए नुकसान को औसत करना चाहेंगे? मुझे यकीन नहीं है कि अगर आप टेंसोफ़्लो से परिचित हैं, लेकिन मैं इस ट्यूटोरियल के साथ अपनी समझ को समेटने की कोशिश कर रहा हूं: टेंसोरफ़्लो। ऑर्गेट / स्टार्टेड / मनिस्ट / एबगिनर्स आप देख सकते हैं कि नुकसान बैच पर कम हो गया है (कम करें कोड)। मुझे लगता है कि टेंसरफ्लो वजन की आंतरिक गणना / औसत रखता है?
कार्बोक्पुटेड

1
@carboncomputed ओह हाँ आप सही कह रहे हैं, वे नुकसान का औसत निकालते हैं ताकि जब टेंसरफ़्लो औसत नुकसान के ग्रेडर की गणना करे तो यह प्रत्येक नुकसान के लिए ग्रेडिएंट के औसत को प्रभावी रूप से गणना कर रहा है। मैं इसके लिए गणित दिखाने के लिए अपने उत्तर को संपादित करूँगा।
ह्यूग

दिलचस्प। स्पष्टीकरण के लिए धन्यवाद। तो बस थोड़ा गहरा खोदने के लिए, वेट ग्रेडिएंट्स को प्रति पास फॉरवर्ड पास के दौरान परिकलित किया जाता है और इन्हें टेंसरफ़्लो में ऑप्टिमाइज़ेशन प्रक्रिया के दौरान गणना किया जाता है? मुझे लगता है कि मैं "लापता" हूँ जहाँ "" ये ग्रेडिएंटफ़्लो में ग्रेडिएंट हैं? मैं आगे पास और नुकसान देखता हूं, इसलिए टेंसोफ़्लो मेरे लिए हुड के तहत इन ढाल गणना / औसत कर रहा है?
कार्बोक्पुटेड

1
@carboncomputed यह Tensorflow की अपील है, यह प्रतीकात्मक गणित का उपयोग करता है और हुड के तहत भेदभाव कर सकता है
Hugh

एक साफ जवाब के लिए धन्यवाद। हालांकि, मैं समझने के लिए TF कैसे जानता है के रूप में में दिखाया गया है कि कैसे एक औसत हानि के साथ प्रचार बात नहीं पाई गई इस उदाहरण , code line 170?
पापी

-1

मिनी बैचों का उपयोग करने का कारण अच्छी मात्रा में प्रशिक्षण का उदाहरण है, ताकि इसके प्रभाव को उनके प्रभाव के औसत से कम किया जा सके, लेकिन यह भी एक पूर्ण बैच नहीं है कि कई डेटासेट के लिए स्मृति की एक बड़ी मात्रा की आवश्यकता हो सकती है। एक महत्वपूर्ण तथ्य यह है कि आपके द्वारा मूल्यांकन की जाने वाली त्रुटि हमेशा एक दूरी होती हैआपके अनुमानित आउटपुट और वास्तविक आउटपुट के बीच: इसका मतलब है कि यह नकारात्मक नहीं हो सकता है, इसलिए आपके पास नहीं हो सकता है, जैसा कि आपने कहा, 2 और -2 की एक त्रुटि जो रद्द कर देती है, लेकिन यह बदले में 4 की त्रुटि बन जाएगी। । तब आप सभी भारों के संबंध में त्रुटि के ग्रेडिएंट का मूल्यांकन करते हैं, इसलिए आप यह गणना कर सकते हैं कि कौन सा भार में परिवर्तन इसे सबसे कम करेगा। एक बार जब आप ऐसा करते हैं, तो आप अपने सीखने की दर अल्फा के परिमाण के आधार पर, उस दिशा में एक "कदम" उठाते हैं। (यह मूल अवधारणाएं हैं, मैं गहन एनएन के लिए बैकप्रॉपैगैशन के बारे में विस्तार से नहीं जा रहा हूं) एक निश्चित संख्या में युग के लिए अपने डेटासेट पर इस प्रशिक्षण को चलाने के बाद, आप अपने नेटवर्क से यह उम्मीद कर सकते हैं कि आपका सीखने का चरण बहुत बड़ा नहीं है या नहीं इसे मोड़ो। आप अभी भी एक स्थानीय न्यूनतम में समाप्त कर सकते हैं, यह डिफरेंशियल ऑप्टिमाइज़र का उपयोग करते हुए, अलग-अलग अपने वज़न को इनिशियलाइज़ करके और नियमित करने की कोशिश करके इससे बचा जा सकता है।


बस जोड़ने के लिए: हम मिनी-बैचों का उपयोग ज्यादातर कम्प्यूटेशनल दक्षता के लिए करते हैं। हमारे पास वज़न की सटीकता और भार को अद्यतन करने की आवृत्ति के बीच एक व्यापार-बंद है। मेमोरी में फिट नहीं होने के लिए डेटा बहुत बड़ा होना चाहिए।
सुकस ग्रैड

मैं प्रत्येक को समझता हूं, लेकिन हम एक विशिष्ट बैच के लिए अपना वजन कैसे अपडेट करते हैं? प्रत्येक उदाहरण के लिए वजन ढाल भी अभिव्यक्त किया गया है?
कार्बोक्पुटेड

नहीं, कुल बैच त्रुटि पर केवल एक ग्रेडिएंट है, जो डेरिवेटिव का वेक्टर है। इसका मतलब यह है कि हम एक बार अपने वेट को ग्रेडिएंट के आधार पर अपडेट करते हैं, यानी इस मिनी बैच पर अपडेट की दिशा सबसे ज्यादा घटती है। ग्रेडिएंट आंशिक व्युत्पन्न से बना है, जो प्रत्येक भार के संबंध में व्युत्पन्न pf मिनी बैच त्रुटि है: यह हमें बताता है कि क्या प्रत्येक वजन छोटा या बड़ा होना चाहिए, और कितना। सभी भारों को बैच के लिए एक अद्यतन मिलता है, उस मिनी बैच पर त्रुटि को कम करने के लिए, जो अन्य मिनी बैचों से स्वतंत्र है।
डांटे
हमारी साइट का प्रयोग करके, आप स्वीकार करते हैं कि आपने हमारी Cookie Policy और निजता नीति को पढ़ और समझा लिया है।
Licensed under cc by-sa 3.0 with attribution required.