Pytorch में torch.no_grad का क्या उपयोग है?


22

मैं pytorch में नया हूं और इस github कोड के साथ शुरू हुआ हूं । मुझे कोड में लाइन 60-61 में टिप्पणी समझ नहीं आ रही है "because weights have requires_grad=True, but we don't need to track this in autograd"। मैं समझ गया कि हम requires_grad=Trueउन चरों का उल्लेख करते हैं जिन्हें हमें ऑटोग्राड के उपयोग के लिए ग्रेडिएंट की गणना करने की आवश्यकता है लेकिन इसका क्या मतलब है "tracked by autograd"?

जवाबों:


25

आवरण "टार्च.नो_ग्रेड ()" के साथ अस्थायी रूप से सभी आवश्यक_ग्रेड ध्वज को गलत पर सेट करता है। आधिकारिक PyTorch ट्यूटोरियल ( https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#gradients ) से एक उदाहरण :

x = torch.randn(3, requires_grad=True)
print(x.requires_grad)
print((x ** 2).requires_grad)

with torch.no_grad():
    print((x ** 2).requires_grad)

बाहर:

True
True
False

मैं आपको उपरोक्त वेबसाइट से सभी ट्यूटोरियल पढ़ने की सलाह देता हूं।

आपके उदाहरण में: मुझे लगता है कि लेखक नहीं चाहता है कि PyTorch नए परिभाषित चर w1 और w2 के ग्रेडिएंट्स की गणना करें क्योंकि वह अपने मूल्यों को अपडेट करना चाहता है।


6
with torch.no_grad()

ब्लॉक में सभी प्रचालनों का कोई ग्रेडिएंट नहीं होगा।

Pytorch में, आप w1 और w2 के परिवर्तन को नहीं कर सकते, जो दो चर हैं require_grad = True। मुझे लगता है कि डब्ल्यू 1 और डब्ल्यू 2 के परिवर्तन से बचने के लिए है क्योंकि यह बैक प्रोगैगशन गणना में त्रुटि का कारण होगा। चूंकि विस्थापन परिवर्तन पूरी तरह से डब्ल्यू 1 और डब्ल्यू 2 बदल जाएगा।

हालाँकि, यदि आप इसका उपयोग करते हैं, तो आप no_grad()नए w1 को नियंत्रित कर सकते हैं और नए w2 के पास कोई ग्रेडिएंट नहीं है क्योंकि वे संचालन से उत्पन्न होते हैं, जिसका अर्थ है कि आप केवल w1 और w2 के मूल्य को बदलते हैं, ग्रेडिएंट भाग को नहीं, उनके पास अभी भी पिछले परिभाषित चर ढाल की जानकारी है और वापस प्रचार जारी रह सकता है।

हमारी साइट का प्रयोग करके, आप स्वीकार करते हैं कि आपने हमारी Cookie Policy और निजता नीति को पढ़ और समझा लिया है।
Licensed under cc by-sa 3.0 with attribution required.