स्टोकेस्टिक ग्रेडिएंट डिसेंट्रेंट कैसे मानक ढाल डिसेंट की तुलना में समय बचा सकता है?


16

मानक ढाल वंश पूरे प्रशिक्षण डाटासेट के लिए ढाल की गणना करेगा।

for i in range(nb_epochs):
  params_grad = evaluate_gradient(loss_function, data, params)
  params = params - learning_rate * params_grad

युगों की पूर्व-निर्धारित संख्या के लिए, हम पहले संपूर्ण डाटासेट के लिए नुकसान फ़ंक्शन के ग्रेडिएंट वेक्टर वेट्स_ग्रेड की गणना करते हैं, जो हमारे पैरामीटर वेक्टर पैर्ट्स को प्रभावित करता है।

इसके विपरीत स्टोचस्टिक ग्रेडिएंट डीसेंट प्रत्येक प्रशिक्षण उदाहरण x (i) और लेबल y (i) के लिए एक पैरामीटर अपडेट करता है।

for i in range(nb_epochs):
  np.random.shuffle(data)
  for example in data:
    params_grad = evaluate_gradient(loss_function, example, params)
    params = params - learning_rate * params_grad

कहा जाता है कि यह बहुत तेज है। हालांकि, मुझे समझ में नहीं आता है कि अगर हम अभी भी सभी डेटा बिंदुओं पर लूप रखते हैं तो यह कितना तेज हो सकता है। क्या जीडी में ढाल की गणना प्रत्येक डेटा बिंदु के लिए जीडी की गणना की तुलना में बहुत धीमी है?

यहां से कोड आता है


1
दूसरे मामले में आप पूरे डेटा सेट को अनुमानित करने के लिए एक छोटा बैच लेंगे। यह आमतौर पर बहुत अच्छी तरह से काम करता है। तो भ्रामक भाग शायद यह है कि ऐसा लगता है कि दोनों मामलों में युगों की संख्या समान है, लेकिन आपको मामले में कई युगों की आवश्यकता नहीं होगी। उन दो तरीकों के लिए "हाइपरपरमेटर्स" अलग होंगे: जीडी nb -epochs! = SGD nb_epochs। आइए तर्क के उद्देश्य के लिए कहें: GD nb_epochs = SGD उदाहरण * nb_epochs, ताकि कुल लूप की संख्या समान हो, लेकिन ग्रेडिएंट की गणना का रास्ता SGD में तेज है।
नीमा मौसवी

सीवी पर यह उत्तर एक अच्छा और संबंधित है।
झब्बार

जवाबों:


24

संक्षिप्त जवाब:

  • कई बड़े डेटा सेटिंग में (कई मिलियन डेटा पॉइंट्स कहें), लागत या ग्रेडिएंट की गणना में बहुत लंबा समय लगता है, क्योंकि हमें कुछ डेटा पॉइंट्स पर योग करने की आवश्यकता होती है।
  • हमें दिए गए पुनरावृत्ति में लागत को कम करने के लिए सटीक ढाल की आवश्यकता नहीं है । ढाल के कुछ सन्निकटन ठीक काम करेगा।
  • स्टोचस्टिक ग्रेडिएंट सभ्य (SGD) केवल एक डेटा बिंदु का उपयोग करके ढाल को अनुमानित करता है। इसलिए, सभी डेटा की तुलना में ढाल का मूल्यांकन करने में बहुत समय बचता है।
  • पुनरावृत्तियों की "उचित" संख्या के साथ (यह संख्या हजारों की संख्या में हो सकती है, और डेटा बिंदुओं की संख्या से बहुत कम है, जो लाखों हो सकती है), स्टोकेस्टिक क्रमिक सभ्य को एक उचित अच्छा समाधान मिल सकता है।

लंबा जवाब:

एंड्रयू एनजी के मशीन लर्निंग कोर्टसेरा कोर्स के बाद मेरा अंकन है। यदि आप इससे परिचित नहीं हैं, तो आप यहाँ व्याख्यान श्रृंखला की समीक्षा कर सकते हैं

चलो चुकता नुकसान पर प्रतिगमन मान लेते हैं, लागत समारोह है

J(θ)=12mi=1m(hθ(x(i))y(i))2

और ढाल है

dJ(θ)dθ=1mi=1m(hθ(x(i))y(i))x(i)

ग्रेडिएंट सभ्य (GD) के लिए, हम पैरामीटर को अपडेट करते हैं

θnew=θoldα1mi=1m(hθ(x(i))y(i))x(i)

1/mx(i),y(i)

θnew=θoldα(hθ(x(i))y(i))x(i)

यहाँ हम समय क्यों बचा रहे हैं:

मान लीजिए कि हमारे पास 1 बिलियन डेटा पॉइंट हैं।

  • जीडी में, मापदंडों को एक बार अपडेट करने के लिए, हमें सटीक (सटीक) ढाल की आवश्यकता है। इसके लिए 1 अद्यतन करने के लिए इन 1 बिलियन डेटा बिंदुओं को योग करने की आवश्यकता है।

  • स्वस्थानी में, हम इसे सटीक ढाल के बजाय एक अनुमानित ढाल प्राप्त करने की कोशिश के रूप में सोच सकते हैं । सन्निकटन एक डेटा बिंदु (या कई डेटा बिंदु जिन्हें मिनी बैच कहा जाता है) से आ रहा है। इसलिए, SGD में, हम मापदंडों को बहुत जल्दी अपडेट कर सकते हैं। इसके अलावा, यदि हम सभी डेटा (जिसे एक युग कहा जाता है) पर "लूप" करते हैं, तो वास्तव में हमारे पास 1 बिलियन अपडेट हैं।

चाल यह है कि, SGD में आपको 1 बिलियन पुनरावृत्तियों / अपडेट की आवश्यकता नहीं है, लेकिन बहुत कम पुनरावृत्तियों / अपडेट्स, 1 मिलियन कहते हैं, और आपके पास उपयोग करने के लिए "बहुत अच्छा" मॉडल होगा।


मैं विचार को प्रदर्शित करने के लिए एक कोड लिख रहा हूं। हम पहले रेखीय प्रणाली को सामान्य समीकरण द्वारा हल करते हैं, फिर इसे SGD के साथ हल करते हैं। फिर हम पैरामीटर मान और अंतिम उद्देश्य फ़ंक्शन मान के संदर्भ में परिणामों की तुलना करते हैं। बाद में इसकी कल्पना करने के लिए, हमारे पास धुन करने के लिए 2 पैरामीटर होंगे।

set.seed(0);n_data=1e3;n_feature=2;
A=matrix(runif(n_data*n_feature),ncol=n_feature)
b=runif(n_data)
res1=solve(t(A) %*% A, t(A) %*% b)

sq_loss<-function(A,b,x){
  e=A %*% x -b
  v=crossprod(e)
  return(v[1])
}

sq_loss_gr_approx<-function(A,b,x){
  # note, in GD, we need to sum over all data
  # here i is just one random index sample
  i=sample(1:n_data, 1)
  gr=2*(crossprod(A[i,],x)-b[i])*A[i,]
  return(gr)
}

x=runif(n_feature)
alpha=0.01
N_iter=300
loss=rep(0,N_iter)

for (i in 1:N_iter){
  x=x-alpha*sq_loss_gr_approx(A,b,x)
  loss[i]=sq_loss(A,b,x)
}

परिणाम:

as.vector(res1)
[1] 0.4368427 0.3991028
x
[1] 0.3580121 0.4782659

124.1343123.0355

यहां पुनरावृत्तियों पर लागत फ़ंक्शन मान हैं, हम देख सकते हैं कि यह नुकसान को प्रभावी रूप से कम कर सकता है, जो इस विचार को दिखाता है: हम ग्रेडिएंट को अनुमानित करने और "अच्छे पर्याप्त" परिणाम प्राप्त करने के लिए डेटा के सबसेट का उपयोग कर सकते हैं।

यहाँ छवि विवरण दर्ज करें

यहाँ छवि विवरण दर्ज करें

1000sq_loss_gr_approx3001000


मैंने सोचा कि "गति" के बारे में तर्क इस बात के बारे में अधिक है कि स्थानीय इष्टतम को बदलने के लिए कितने ऑपरेशन / पुनरावृत्तियों की आवश्यकता है? (और यह भी कि स्टोकेस्टिक ग्रेडिएंट
डीसेंट

जहाँ तक मुझे समझ में आया, अजगर कोड में मैंने "डेटा" प्रदान किया है, -वर्तनीय समान है। मिनी बैच ग्रेडिएंट सभ्य - कोड एसडीजी से अलग है (और वास्तव में वहां वह केवल डेटा का एक छोटा सा हिस्सा उपयोग करता है)। इसके अलावा, आपके द्वारा दिए गए स्पष्टीकरण में, हालांकि हमें एसडीजी में राशि से छुटकारा मिलता है, फिर भी हम प्रत्येक डेटा बिंदु के लिए अपडेट की गणना करते हैं। मुझे अभी भी समझ में नहीं आया है कि प्रत्येक डेटा बिंदु पर लूप करते समय एक पैरामीटर को कैसे अपडेट किया जाता है, बस एक बार में सभी डेटा पॉइंट पर राशि लेने की तुलना में तेज़ है।
अलीना

@ GeoMatt22 लिंक में मैंने इसे प्रदान किया है: "दूसरी ओर, यह अंत में सटीक न्यूनतम करने के लिए अभिसरण को जटिल बनाता है, क्योंकि एसडब्ल्यूडी निगरानी बनाए रखेगा।" मतलब यह बेहतर ऑप्टिमा में परिवर्तित नहीं होता है। या मुझे गलत लगा?
अलीना

@ तोंजा मैं कोई विशेषज्ञ नहीं हूं, लेकिन उदाहरण के लिए गहरे शिक्षण में यह अत्यधिक प्रभावशाली पेपर स्टोचस्टिक ढाल वंश के लिए "अधिक तेज़ विश्वसनीय प्रशिक्षण" तर्क देता है। ध्यान दें कि यह "कच्चे" संस्करण का उपयोग नहीं करता है, लेकिन सीखने की दर (समन्वय-निर्भर) को निर्धारित करने के लिए विभिन्न वक्रता अनुमानों का उपयोग करता है।
जियोमैट

1
@ तोंजा, हाँ। ढाल का कोई भी "कमजोर" सन्निकटन काम करेगा। आप "ग्रेडिंग बूस्टिंग" की जांच कर सकते हैं, जो समान विचार है। दूसरी ओर, मैं विचार को प्रदर्शित करने के लिए कुछ कोड लिख रहा हूं। जब यह तैयार होगा तब मैं इसे पोस्ट करूंगा।
Haitao Du
हमारी साइट का प्रयोग करके, आप स्वीकार करते हैं कि आपने हमारी Cookie Policy और निजता नीति को पढ़ और समझा लिया है।
Licensed under cc by-sa 3.0 with attribution required.