प्रतिगमन के लिए सीएनएन आर्किटेक्चर?


32

मैं एक प्रतिगमन समस्या पर काम कर रहा हूं जहां इनपुट एक छवि है, और लेबल 80 और 350 के बीच एक निरंतर मूल्य है। प्रतिक्रिया के बाद चित्र कुछ रसायनों के होते हैं। जो रंग निकलता है, वह किसी अन्य रसायन की एकाग्रता को इंगित करता है जिसे छोड़ दिया जाता है, और यही वह मॉडल है जो आउटपुट के लिए है - उस रसायन की एकाग्रता। छवियों को घुमाया जा सकता है, फ़्लिप किया जा सकता है, प्रतिबिंबित किया जा सकता है, और अपेक्षित आउटपुट अभी भी समान होना चाहिए। इस तरह का विश्लेषण वास्तविक प्रयोगशालाओं में किया जाता है (बहुत ही विशिष्ट मशीनें रंग विश्लेषण का उपयोग करके रसायनों की एकाग्रता का उत्पादन करती हैं जैसे मैं इस मॉडल को करने के लिए प्रशिक्षण दे रहा हूं)।

अब तक मैंने केवल वीजीजी से जुड़े मॉडल के साथ प्रयोग किया है (कनव-कनव-कनव-पूल ब्लॉक के कई क्रम)। अधिक हाल के आर्किटेक्चर (इंसेप्शन, रेसनेट्स आदि) के साथ प्रयोग करने से पहले, मैंने सोचा कि अगर अन्य चित्र जो आमतौर पर छवियों का उपयोग करके प्रतिगमन के लिए उपयोग किए जाते हैं, तो मैं शोध करूँगा।

डेटासेट इस तरह दिखता है:

यहां छवि विवरण दर्ज करें

डेटासेट में लगभग 5,000 250x250 नमूने हैं, जिन्हें मैंने 64x64 में बदल दिया है, इसलिए प्रशिक्षण आसान है। एक बार जब मुझे एक आशाजनक वास्तुकला मिल जाता है, तो मैं बड़े संकल्प चित्रों के साथ प्रयोग करूँगा।

अब तक, मेरे सबसे अच्छे मॉडल में लगभग 0.3 के प्रशिक्षण और सत्यापन सेट दोनों पर एक औसत चुकता त्रुटि है, जो मेरे उपयोग के मामले में स्वीकार्य से दूर है।

मेरा सबसे अच्छा मॉडल अब तक इस तरह दिखता है:

// pseudo code
x = conv2d(x, filters=32, kernel=[3,3])->batch_norm()->relu()
x = conv2d(x, filters=32, kernel=[3,3])->batch_norm()->relu()
x = conv2d(x, filters=32, kernel=[3,3])->batch_norm()->relu()
x = maxpool(x, size=[2,2], stride=[2,2])

x = conv2d(x, filters=64, kernel=[3,3])->batch_norm()->relu()
x = conv2d(x, filters=64, kernel=[3,3])->batch_norm()->relu()
x = conv2d(x, filters=64, kernel=[3,3])->batch_norm()->relu()
x = maxpool(x, size=[2,2], stride=[2,2])

x = conv2d(x, filters=128, kernel=[3,3])->batch_norm()->relu()
x = conv2d(x, filters=128, kernel=[3,3])->batch_norm()->relu()
x = conv2d(x, filters=128, kernel=[3,3])->batch_norm()->relu()
x = maxpool(x, size=[2,2], stride=[2,2])

x = dropout()->conv2d(x, filters=128, kernel=[1, 1])->batch_norm()->relu()
x = dropout()->conv2d(x, filters=32, kernel=[1, 1])->batch_norm()->relu()

y = dense(x, units=1)

// loss = mean_squared_error(y, labels)

सवाल

छवि इनपुट से प्रतिगमन आउटपुट के लिए एक उपयुक्त वास्तुकला क्या है?

संपादित करें

मैंने अपने स्पष्टीकरण को फिर से परिभाषित किया है और सटीकता के उल्लेखों को हटा दिया है।

संपादित करें 2

मैंने अपने प्रश्न का पुनर्गठन किया है इसलिए उम्मीद है कि यह स्पष्ट है कि मैं आखिर क्या हूँ


4
सटीकता एक ऐसा उपाय नहीं है जिसे सीधे प्रतिगमन समस्याओं पर लागू किया जा सकता है। जब आप कहते हैं कि आपकी सटीकता 30% है तो आपका क्या मतलब है? सटीकता वास्तव में केवल वर्गीकरण कार्यों पर लागू होती है, प्रतिगमन पर नहीं।
न्यूक्लियर वैंग

1
"30% समय सही ढंग से भविष्यवाणी करता है " से आपका क्या अभिप्राय है ? क्या आप वास्तव में प्रतिगमन कर रहे हैं?
फायरबग

1
आप इस समस्या को प्रतिगमन क्यों कहते हैं? क्या आप लेबल में वर्गीकृत करने का प्रयास नहीं कर रहे हैं? लेबल कार्डिनल हैं?
अक्कल

2
मैं ठीक वैसी ही बात नहीं चाहता जैसा कि vgg। मैं कुछ ऐसा कर रहा हूं, जिसका अर्थ है अधिकतम पूलिंग के बाद मिलने वाली श्रंखला की श्रृंखला, जिसके बाद पूरी तरह से जुड़ा हुआ। छवियों के साथ काम करने के लिए एक सामान्य दृष्टिकोण की तरह लगता है। लेकिन फिर, यह मेरे मूल प्रश्न का पूरा बिंदु है। इन सभी टिप्पणियों की तरह लगता है, हालांकि मेरे लिए व्यावहारिक, पूरी तरह से मैं पहले स्थान पर पूछ रहा हूँ की बात याद आती है।
रॉडरिगो-सिल्वेरा

1
इसके अलावा, यदि आप समस्या का बेहतर विवरण देते हैं तो हम बेहतर मदद दे सकते हैं। 1) चित्र क्या हैं? उनका संकल्प क्या है? क्या संबंध छवियों और आपकी प्रतिक्रिया, के बीच है ? इस रिश्ते रोटेशन-अपरिवर्तनीय, यानी, अगर मैं एक मनमाना कोण से अपने परिपत्र छवि को घुमाने है θ , मैं उम्मीद कर y परिवर्तन के लिए? 2) क्या आप जानते हैं कि एक वीजीजी-नेट वास्तुकला को प्रशिक्षित करने के लिए 5000 छवियां एक दुख है? क्या आपने अपनी वास्तुकला के मापदंडों की संख्या की गणना की है? क्या कोई ऐसा तरीका है जिससे आप अधिक चित्र प्राप्त कर सकते हैं? यदि आप नहीं कर सकते हैं, तो शायद आपको जरूरत y[80,350]θy
पड़े

जवाबों:


42

सबसे पहले एक सामान्य सुझाव: आप जिस विषय से परिचित नहीं हैं उस विषय पर प्रयोग करना शुरू करने से पहले एक साहित्य खोज करें। आप अपने आप को बहुत समय बचा लेंगे।

इस मामले में, मौजूदा कागजात को देखकर आपने गौर किया होगा

  1. प्रतिगमन के लिए सीएनएन का कई बार उपयोग किया गया है: यह एक क्लासिक है लेकिन यह पुराना है (हाँ, 3 साल डीएल में पुराना है)। इस कार्य के लिए एक अधिक आधुनिक पेपर ने एलेक्सनेट का उपयोग नहीं किया होगा। यह अधिक हाल ही में है, लेकिन यह एक बहुत अधिक जटिल समस्या (3 डी रोटेशन) के लिए है, और वैसे भी मैं इससे परिचित नहीं हूं।
  2. सीएनएन के साथ प्रतिगमन एक तुच्छ समस्या नहीं है। पहले पेपर को फिर से देखते हुए, आप देखेंगे कि उनके पास एक समस्या है जहां वे मूल रूप से अनंत डेटा उत्पन्न कर सकते हैं। उनका उद्देश्य 2D चित्रों को ठीक करने के लिए आवश्यक रोटेशन कोण की भविष्यवाणी करना है। इसका मतलब यह है कि मैं मूल रूप से अपना प्रशिक्षण सेट ले सकता हूं और प्रत्येक छवि को मनमाने कोण से घुमाकर इसे बढ़ा सकता हूं, और मैं एक मान्य, बड़ा प्रशिक्षण सेट प्राप्त करूंगा। इस प्रकार यह समस्या अपेक्षाकृत सरल है, जहाँ तक डीप लर्निंग की समस्याएँ हैं। वैसे, अन्य डेटा वृद्धि के ट्रिक्स पर ध्यान दें जो वे उपयोग करते हैं:

    हम अनुवाद का उपयोग करते हैं (छवि चौड़ाई का 5% तक), सीमा में चमक समायोजन [(0.2, 0.2], γ ∈ [.50.5, 0.1] के साथ गामा समायोजन और सीमा में एक मानक विचलन के साथ गॉसियन पिक्सेल शोर [0] , 0.02]।

    k

    yxα=atan2(y,x)>11%अधिकतम संभव त्रुटि की। उन्होंने श्रृंखला में दो नेटवर्क का उपयोग करके थोड़ा बेहतर किया: पहला वर्गीकरण का प्रदर्शन करेगा (भविष्यवाणी करें कि क्या कोण या वर्ग), तब छवि, जो पहले नेटवर्क द्वारा भविष्यवाणी की गई राशि से घूमती है, दूसरे तंत्रिका नेटवर्क (प्रतिगमन के लिए, इस समय) को खिलाया जाएगा, जो अंतिम अतिरिक्त रोटेशन की भविष्यवाणी करेगा रेंज।[180°,90°],[90°,0°],[0°,90°][90°,180°][45°,45°]

    बहुत सरल (घुमाए गए MNIST) समस्या पर, आप कुछ बेहतर प्राप्त कर सकते हैं , लेकिन फिर भी आप RMSE त्रुटि से नीचे नहीं जाते हैं, जो अधिकतम संभावित त्रुटि का है।2.6%

तो, हम इससे क्या सीख सकते हैं? सबसे पहले, कि 5000 छवियां आपके कार्य के लिए एक छोटा डेटा सेट है। पहले पेपर में एक नेटवर्क का उपयोग किया गया था, जो छवियों के समान था, जिसके लिए वे प्रतिगमन कार्य सीखना चाहते थे: न केवल आपको उससे अलग कार्य सीखने की आवश्यकता है, जिसके लिए वास्तुकला को डिजाइन किया गया था (वर्गीकरण), लेकिन आपका प्रशिक्षण सेट doesn सभी प्रशिक्षण सेटों की तरह कुछ भी न देखें जिस पर ये नेटवर्क आमतौर पर प्रशिक्षित होते हैं (CIFAR-10/100 या ImageNet)। इसलिए आपको शायद ट्रांसफर लर्निंग से कोई लाभ नहीं मिलेगा। MATLAB उदाहरण में 5000 छवियां थीं, लेकिन वे काले और सफेद थे और शब्दार्थ सभी समान थे (ठीक है, यह आपका मामला भी हो सकता है)।

फिर, कैसे यथार्थवादी 0.3 से बेहतर कर रहा है? हमें सबसे पहले समझना चाहिए कि 0.3 औसत नुकसान से आपका क्या मतलब है। क्या आपका मतलब है कि RMSE त्रुटि 0.3 है,

1Ni=1N(h(xi)yi)2

जहां आपके प्रशिक्षण सेट का आकार है (इस प्रकार, ), छवि के लिए आपके CNN का आउटपुट है और रासायनिक की संगत एकाग्रता है? बाद से , फिर यह मानते हुए कि आप अपने सीएनएन की भविष्यवाणियों को और ३५० के बीच क्लिप करते हैं (या आप सिर्फ एक अंतराल का उपयोग करके उन्हें उस अंतराल में फिट बनाते हैं), आपको से कम त्रुटि हो रही है । गंभीरता से, आप क्या उम्मीद करते हैं? यह मुझे एक बड़ी त्रुटि नहीं लगती है।NN<5000h(xi)xiyiyi[80,350]0.12%

इसके अलावा, बस अपने नेटवर्क में मापदंडों की संख्या की गणना करने का प्रयास करें: मैं जल्दी में हूं और मैं मूर्खतापूर्ण गलतियां कर सकता हूं, इसलिए हर तरह summaryसे जो भी फ्रेमवर्क आप उपयोग कर रहे हैं, उससे कुछ फ़ंक्शन के साथ मेरी गणनाओं की दोहरी जांच करें । हालांकि, मोटे तौर पर मैं कहूंगा कि आपके पास है

9×(3×32+2×32×32+32×64+2×64×64+64×128+2×128×128)+128×128+128×32+32×32×32=533344

(ध्यान दें कि मैंने बैच नॉर्म्स लेयर्स के पैरामीटर्स को छोड़ दिया है, लेकिन वे लेयर के लिए सिर्फ 4 पैरामीटर्स हैं, ताकि उन्हें फर्क न पड़े)। आपके पास आधा मिलियन पैरामीटर और 5000 उदाहरण हैं ... आप क्या उम्मीद करेंगे? ज़रूर, मापदंडों की संख्या एक तंत्रिका नेटवर्क (यह एक गैर-पहचान योग्य मॉडल) की क्षमता के लिए एक अच्छा संकेतक नहीं है, लेकिन फिर भी ... मुझे नहीं लगता कि आप इससे बेहतर कर सकते हैं, लेकिन आप एक कोशिश कर सकते हैं कुछ बातें:

  • सभी इनपुट को सामान्य करें (उदाहरण के लिए, -1 और 1 के बीच प्रत्येक पिक्सेल के RGB तीव्रता को फिर से व्यवस्थित करें, या मानकीकरण का उपयोग करें) और सभी आउटपुट। यदि आपके पास अभिसरण मुद्दे हैं तो यह विशेष रूप से मदद करेगा।
  • ग्रेस्केल पर जाएं: यह आपके इनपुट चैनलों को 3 से 1. कम कर देगा। आपकी सभी छवियां अपेक्षाकृत समान रंगों की होंगी। क्या आप सुनिश्चित हैं कि यह वह रंग है जिसकी आवश्यकता भविष्यवाणी करने के लिए है , न कि गहरे या चमकीले क्षेत्रों के अस्तित्व की? शायद आपको यकीन हो (मैं विशेषज्ञ नहीं हूं): इस मामले में इस सुझाव को छोड़ दें।y
  • डेटा वृद्धि: चूंकि आपने कहा था कि फ़्लिप करना, किसी अनियंत्रित कोण द्वारा घूमना या आपकी छवियों को मिरर करना एक ही आउटपुट में परिणामित होना चाहिए, आप अपने डेटा के आकार को बहुत बढ़ा सकते हैं । ध्यान दें कि बड़े डेटासेट के साथ प्रशिक्षण सेट पर त्रुटि बढ़ जाएगी: हम यहां जो देख रहे हैं वह प्रशिक्षण सेट हानि और परीक्षण सेट नुकसान के बीच एक छोटा अंतर है। इसके अलावा, यदि प्रशिक्षण सेट हानि बहुत बढ़ जाती है, तो यह अच्छी खबर हो सकती है: इसका मतलब यह हो सकता है कि आप ओवरफिटिंग के जोखिम के बिना इस बड़े प्रशिक्षण सेट पर एक गहन नेटवर्क को प्रशिक्षित कर सकते हैं। अधिक परतें जोड़ने का प्रयास करें और देखें कि क्या अब आपको एक छोटा प्रशिक्षण सेट और परीक्षण सेट नुकसान मिलता है। अंत में, आप ऊपर उद्धृत अन्य डेटा वृद्धि ट्रिक्स को भी आज़मा सकते हैं, अगर वे आपके आवेदन के संदर्भ में समझ में आते हैं।
  • वर्गीकरण-तब-प्रतिगमन ट्रिक का उपयोग करें: एक पहला नेटवर्क केवल यह निर्धारित करता है कि एक में होना चाहिए, कहते हैं, 10 डिब्बे, जैसे कि , आदि। एक दूसरा नेटवर्क तब सुधार की गणना करता है : यहां भी मदद मिल सकती है और सामान्य हो सकती है। बिना कोशिश किए नहीं कह सकता।y[80,97],[97,124][0,27]
  • एक पुराने के बजाय एक आधुनिक वास्तुकला (इन्सेप्शन या रेसनेट) का उपयोग करने का प्रयास करें। ResNet वास्तव में VGG-net की तुलना में कम पैरामीटर है। बेशक, आप यहां छोटे ResNets का उपयोग करना चाहते हैं - मुझे नहीं लगता कि ResNet-101 5000 छवियों के डेटा सेट पर मदद कर सकता है। आप डेटा सेट को बहुत बढ़ा सकते हैं, हालाँकि ...।
  • चूँकि आपका आउटपुट घूर्णन के लिए अपरिवर्तनीय है, इसलिए एक और बढ़िया विचार यह होगा कि समूह तुल्यकालिक CNNs का उपयोग किया जाए , जिसका आउटपुट (जब क्लासिफायर के रूप में उपयोग किया जाता है) घूर्णन को असतत करने के लिए अपरिवर्तनीय होता है , या सुगम CNNsजिसका उत्पादन निरंतर घुमावों के लिए अपरिवर्तनीय है। आक्रमणकारी संपत्ति आपको बहुत कम डेटा वृद्धि के साथ अच्छे परिणाम प्राप्त करने की अनुमति देती है, या आदर्श रूप से कोई भी नहीं (इसके लिए जो रोटेशन की चिंता करता है: बेशक आपको अभी भी अन्य प्रकार के दा की आवश्यकता है)। ग्रुप इक्वेरिएंट CNNs कार्यान्वयन के दृष्टिकोण से स्टीयरेबल CNN से अधिक परिपक्व होते हैं, इसलिए मैं पहले समूह CNNs की कोशिश करूँगा। आप वर्गीकरण भाग के लिए जी-सीएनएन का उपयोग करके, वर्गीकरण-तब-प्रतिगमन की कोशिश कर सकते हैं, या आप शुद्ध प्रतिगमन दृष्टिकोण के साथ प्रयोग कर सकते हैं। तदनुसार शीर्ष परत को बदलने के लिए याद रखें।
  • बैच के आकार के साथ प्रयोग (हाँ, हाँ, मुझे पता है कि हाइपरपरमेटर्स-हैकिंग शांत नहीं है, लेकिन यह सबसे अच्छा है जो मैं सीमित समय सीमा में और मुफ्त में आ सकता है :-)
  • अंत में, ऐसे आर्किटेक्चर हैं जिन्हें विशेष रूप से छोटे डेटा सेट के साथ सटीक पूर्वानुमान बनाने के लिए विकसित किया गया है। उनमें से ज्यादातर ने पतले संकल्पों का उपयोग किया : एक प्रसिद्ध उदाहरण मिश्रित पैमाने पर घने विक्षेपीय तंत्रिका नेटवर्क है । कार्यान्वयन तुच्छ नहीं है, हालांकि।

3
विस्तृत उत्तर के लिए धन्यवाद। मैं पहले से ही महत्वपूर्ण डेटा वृद्धि कर रहा था। इंसेप्शन मॉडल (जहाँ भिन्नता का अर्थ है कि फ़िल्टर की संख्या पूरे मॉडल में समान रूप से मापी गई है) के कुछ रूपों की कोशिश की। अविश्वसनीय सुधार देखा। अभी भी जाने के लिए एक रास्ता है। मैं आपके कुछ सुझावों की कोशिश करूँगा। एक बार फिर धन्यवाद।
रॉडरिगो-सिल्वेरा

@ rodrigo-silveira आपका स्वागत है, मुझे बताएं कि यह कैसे जाता है। हो सकता है कि परिणाम आने के बाद हम आपसे बात कर सकें।
डेल्टिव

1
शानदार उत्तर, और अधिक ^ के योग्य हैं
गिली

1
बहुत बढ़िया रचना!
कार्तिक त्यागराज

1
अगर मैं कर सकता तो मैं आपको इसके लिए 10k अंक देता। अद्भुत जवाब
Boppity Bop
हमारी साइट का प्रयोग करके, आप स्वीकार करते हैं कि आपने हमारी Cookie Policy और निजता नीति को पढ़ और समझा लिया है।
Licensed under cc by-sa 3.0 with attribution required.