PyTorch में प्रशिक्षित मॉडल को बचाने का सबसे अच्छा तरीका?


193

मैं PyTorch में एक प्रशिक्षित मॉडल को बचाने के लिए वैकल्पिक तरीकों की तलाश कर रहा था। अब तक, मुझे दो विकल्प मिल गए हैं।

  1. एक मॉडल को बचाने के लिए torch.save () और एक मॉडल को लोड करने के लिए torch.load ()
  2. model.state_dict () एक प्रशिक्षित मॉडल को बचाने के लिए और मॉडल को सहेजने के लिए model.load_state_dict ()

मैं इस चर्चा में आया हूं कि दृष्टिकोण 2 को दृष्टिकोण 1 से अधिक की सिफारिश की गई है।

मेरा सवाल यह है कि दूसरा दृष्टिकोण क्यों पसंद किया जाता है? क्या यह केवल इसलिए है क्योंकि torch.nn मॉड्यूल में वे दो कार्य हैं और हमें उनका उपयोग करने के लिए प्रोत्साहित किया जाता है?


2
मुझे लगता है कि यह इसलिए है क्योंकि torch.save () सभी मध्यवर्ती चरों को बचाती है, जैसे कि पीछे के उपयोग के लिए मध्यवर्ती आउटपुट। लेकिन आपको केवल मॉडल मापदंडों को बचाने की आवश्यकता है, जैसे वजन / पूर्वाग्रह आदि। कभी-कभी पूर्व बाद की तुलना में बहुत बड़ा हो सकता है।
दावई यांग

2
मैं परीक्षण किया torch.save(model, f)और torch.save(model.state_dict(), f)। सहेजी गई फ़ाइलों का आकार समान है। अब मैं उलझन में हूं। इसके अलावा, मैंने मॉडल.स्टेट_डिक्ट () को बचाने के लिए अचार का उपयोग किया। मुझे लगता torch.save(model.state_dict(), f)है कि मॉडल के निर्माण को संभालने के बाद से उपयोग करने का सबसे अच्छा तरीका है , और मशाल मॉडल के भार को संभालती है, इस प्रकार संभावित मुद्दों को समाप्त कर देती है। संदर्भ: चर्चा। Pytorch.org/t/saving-torch-models/838/4
दावी यांग

PyTorch की तरह लगता है कि यह उनके ट्यूटोरियल अनुभाग में थोड़ी अधिक स्पष्ट रूप से संबोधित किया गया है- बहुत अच्छी जानकारी है कि यहां जवाबों में सूचीबद्ध नहीं है, जिसमें एक समय में एक से अधिक मॉडल को सहेजना और गर्म शुरुआती मॉडल शामिल हैं।
whlteXbread

उपयोग करने में क्या गलत है pickle?
चार्ली पार्कर

1
@CharlieParker torch.save अचार पर आधारित है। निम्नलिखित लिंक ऊपर दिए गए ट्यूटोरियल से है: "[torch.save] पायथन के अचार मॉड्यूल का उपयोग करके पूरे मॉड्यूल को बचाएगा। इस दृष्टिकोण का नुकसान यह है कि क्रमबद्ध डेटा विशिष्ट वर्गों और सटीक निर्देशिका संरचना से बाध्य है जब मॉडल का उपयोग किया जाता है। सहेजा गया है। इसका कारण यह है क्योंकि अचार मॉडल वर्ग को ही नहीं बचाता है। बल्कि, यह वर्ग युक्त फ़ाइल के लिए एक पथ बचाता है, जो लोड समय के दौरान उपयोग किया जाता है। इस वजह से, आपका कोड विभिन्न तरीकों से टूट सकता है जब। अन्य परियोजनाओं में या रिफ्लेक्टर के बाद उपयोग किया जाता है। ”
डेविड मिलर

जवाबों:


215

मुझे यह पृष्ठ उनके गिथुब रेपो में मिला है, मैं यहां केवल सामग्री पेस्ट करूंगा।


एक मॉडल को बचाने के लिए अनुशंसित दृष्टिकोण

एक मॉडल को क्रमबद्ध करने और पुनर्स्थापित करने के लिए दो मुख्य दृष्टिकोण हैं।

पहला (अनुशंसित) केवल मॉडल पैरामीटर बचाता और लोड करता है:

torch.save(the_model.state_dict(), PATH)

फिर बाद में:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

दूसरा पूरा मॉडल बचाता और लोड करता है:

torch.save(the_model, PATH)

फिर बाद में:

the_model = torch.load(PATH)

हालांकि इस मामले में, क्रमबद्ध डेटा विशिष्ट वर्गों और उपयोग की जाने वाली सटीक निर्देशिका संरचना से बंधा हुआ है, इसलिए यह अन्य परियोजनाओं में उपयोग किए जाने पर या कुछ गंभीर रिफैक्टर्स के बाद विभिन्न तरीकों से टूट सकता है।


8
@Smth के अनुसार discuss.pytorch.org/t/saving-and-loading-a-model-in-pytorch/... डिफ़ॉल्ट रूप से मॉडल को प्रशिक्षित करने के मॉडल पुनः लोड। इसलिए मैन्युअल रूप से the_model.eval () को लोड करने के बाद कॉल करने की आवश्यकता है, यदि आप इसे अनुमान के लिए लोड कर रहे हैं, तो प्रशिक्षण फिर से शुरू न करें।
22

दूसरी विधि stackoverflow.com/questions/53798009// खिड़कियों पर त्रुटि देती है । 10. इसे हल करने में सक्षम नहीं थी
गुलज़ार

क्या मॉडल वर्ग की पहुँच के लिए आवश्यकता के बिना बचत करने का कोई विकल्प है?
माइकल डी

उस दृष्टिकोण के साथ आप कैसे लोड मामले के लिए पारित करने के लिए आवश्यक * आर्ग और ** kwargs का ट्रैक रखते हैं?
मरिआनो काम्प

उपयोग करने में क्या गलत है pickle?
चार्ली पार्कर

144

यह आप पर निर्भर करता है की आप क्या करना चाहते हो।

केस # 1: मॉडल को अपने बचाव के लिए उपयोग करने के लिए सहेजें : आप मॉडल को सहेजते हैं, आप इसे पुनर्स्थापित करते हैं, और फिर आप मॉडल को मूल्यांकन मोड में बदलते हैं। यह इसलिए किया जाता है क्योंकि आपके पास आमतौर पर BatchNormऔर Dropoutपरतें होती हैं जो डिफ़ॉल्ट रूप से निर्माण पर ट्रेन मोड में होती हैं:

torch.save(model.state_dict(), filepath)

#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()

केस # 2: बाद में प्रशिक्षण को फिर से शुरू करने के लिए मॉडल को सहेजें : यदि आपको उस मॉडल को प्रशिक्षित करने की आवश्यकता है जिसे आप सहेजने वाले हैं, तो आपको केवल मॉडल से अधिक बचत करने की आवश्यकता है। आपको अनुकूलक, युग, स्कोर, आदि की स्थिति को बचाने की आवश्यकता है। आप इसे इस तरह से करेंगे:

state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

प्रशिक्षण को फिर से शुरू करने के लिए आप निम्न चीजें करेंगे: state = torch.load(filepath)और फिर, प्रत्येक व्यक्तिगत वस्तु की स्थिति को पुनर्स्थापित करने के लिए, कुछ इस तरह से:

model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

चूंकि आप प्रशिक्षण फिर से शुरू कर रहे हैं, लोड करते समय राज्यों को पुनर्स्थापित करने के बाद कॉल करें model.eval()

केस # 3: आपके कोड तक पहुंच के बिना किसी और द्वारा उपयोग किया जाने वाला मॉडल : Tensorflow में आप एक ऐसी .pbफाइल बना सकते हैं जो मॉडल के आर्किटेक्चर और भार दोनों को परिभाषित करती है। यह बहुत उपयोगी है, विशेष रूप से उपयोग करते समय Tensorflow serve। Pytorch में ऐसा करने का समान तरीका होगा:

torch.save(model, filepath)

# Then later:
model = torch.load(filepath)

यह तरीका अभी भी बुलेट प्रूफ नहीं है और चूँकि अभी भी कई बदलाव हो रहे हैं, मैं इसकी सिफारिश नहीं करूँगा।


1
क्या 3 मामलों के लिए एक अनुशंसित फ़ाइल समाप्त हो रही है? या यह हमेशा .pth है?
वेरना हाउंस्च्मिड

1
केस # 3 में torch.loadसिर्फ एक आदेश दिया गया है। भविष्यवाणियां करने के लिए आपको मॉडल कैसे मिलता है?
अल्बर्ट 8295

नमस्ते, क्या मुझे पता है कि "केस # 2: बाद में प्रशिक्षण फिर से शुरू करने के लिए मॉडल सहेजें" का उल्लेख कैसे किया जाए? मैं चेकपॉइंट को मॉडल में लोड करने में कामयाब रहा, फिर मैं "Model.to (डिवाइस) मॉडल = train_model_epoch (मॉडल, मानदंड, अनुकूलक, अनुसूची, युगीन) जैसे ट्रेन मॉडल को चलाने या फिर से शुरू करने में असमर्थ"
dzz

1
हाय, एक मामले में जो कि अनुमान के लिए है, आधिकारिक पाइरॉच डॉक्टर में कहते हैं कि या तो अनुमान या प्रशिक्षण पूरा करने के लिए ऑप्टिमाइज़र स्टेट_डिक्ट को बचाएं। "जब सामान्य चेकपॉइंट को सहेजना, या तो इंजेक्शन या फिर से शुरू करने के प्रशिक्षण के लिए उपयोग किया जाता है, तो आपको केवल मॉडल के State_dict से अधिक सहेजना होगा। ऑप्टिमाइज़र के State_dict को भी सहेजना महत्वपूर्ण है, क्योंकि इसमें बफ़र और पैरामीटर होते हैं जिन्हें मॉडल ट्रेनों के रूप में अपडेट किया जाता है। । "
मोहम्मद अवनी

1
# 3 के मामले में, मॉडल वर्ग को कहीं परिभाषित किया जाना चाहिए।
माइकल डी

12

अचार serializing और एक अजगर वस्तु de-serializing के लिए अजगर पुस्तकालय औजार द्विआधारी प्रोटोकॉल।

जब आप import torch(या जब आप PyTorch का उपयोग करते हैं) तो यह import pickleआपके लिए होगा और आपको कॉल करने की आवश्यकता नहीं है pickle.dump()और pickle.load()सीधे, जो कि ऑब्जेक्ट को बचाने और लोड करने के तरीके हैं।

वास्तव में, torch.save()और torch.load()लपेटो जाएगा pickle.dump()और pickle.load()आप के लिए।

state_dictउल्लेख किया गया एक अन्य उत्तर केवल कुछ और नोटों का हकदार है।

state_dictPyTorch के अंदर हमारे पास क्या है? वास्तव में दो state_dictएस हैं।

PyTorch मॉडल है torch.nn.Moduleहै model.parameters()learnable मानकों (डब्ल्यू और ख) प्राप्त करने के लिए कॉल। एक बार बेतरतीब ढंग से सेट होने पर ये सीखने योग्य पैरामीटर, समय के साथ जैसे हम सीखते हैं, वैसे ही अपडेट होते जाएंगे। सीखने योग्य पैरामीटर पहले हैं state_dict

दूसरा state_dictआशावादी राज्य है। आपको याद है कि हमारे सीखने योग्य मापदंडों को बेहतर बनाने के लिए अनुकूलक का उपयोग किया जाता है। लेकिन आशावादी state_dictतय है। वहां सीखने के लिए कुछ भी नहीं।

क्योंकि state_dictऑब्जेक्ट्स पायथन डिक्शनरी हैं, उन्हें आसानी से बचाया जा सकता है, अपडेट किया जा सकता है, बदल दिया जा सकता है, और बहाल कर दिया जा सकता है, जो PyTorch मॉडल और ऑप्टिमाइज़र के लिए बहुत अधिक मात्रा में जोड़ देता है।

आइए इसे समझाने के लिए एक सुपर सरल मॉडल बनाएं:

import torch
import torch.optim as optim

model = torch.nn.Linear(5, 2)

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print("Model weight:")    
print(model.weight)

print("Model bias:")    
print(model.bias)

print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

यह कोड निम्नलिखित आउटपुट देगा:

Model's state_dict:
weight   torch.Size([2, 5])
bias     torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328,  0.1360,  0.1553, -0.1838, -0.0316],
        [ 0.0479,  0.1760,  0.1712,  0.2244,  0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state    {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]

ध्यान दें कि यह एक न्यूनतम मॉडल है। आप अनुक्रमिक के ढेर को जोड़ने की कोशिश कर सकते हैं

model = torch.nn.Sequential(
          torch.nn.Linear(D_in, H),
          torch.nn.Conv2d(A, B, C)
          torch.nn.Linear(H, D_out),
        )

ध्यान दें कि केवल सीखने योग्य मापदंडों (कंसिस्टेंट लेयर्स, लीनियर लेयर्स इत्यादि) और रजिस्टर्ड बफ़र्स (बैटकमर्म लेयर्स) की परतों में मॉडल की प्रविष्टियाँ हैं state_dict

गैर-सीखने योग्य चीजें, ऑप्टिमाइज़र ऑब्जेक्ट से संबंधित होती हैं state_dict, जिसमें ऑप्टिमाइज़र के राज्य के बारे में जानकारी होती है, साथ ही साथ हाइपरपरमेटर्स का उपयोग किया जाता है।

बाकी कहानी वही है; पूर्वानुमान चरण (यह एक चरण है जब हम प्रशिक्षण के बाद मॉडल का उपयोग करते हैं) भविष्यवाणी के लिए; हमारे द्वारा सीखे गए मापदंडों के आधार पर हम भविष्यवाणी करते हैं। तो अनुमान के लिए, हमें बस मापदंडों को बचाने की जरूरत है model.state_dict()

torch.save(model.state_dict(), filepath)

और बाद में उपयोग करने के लिए model.load_state_dict (torch.load (filepath)) model.eval ()

नोट: पिछली पंक्ति को मत भूलना model.eval()यह मॉडल लोड करने के बाद महत्वपूर्ण है।

इसके अलावा बचाने की कोशिश मत करो torch.save(model.parameters(), filepath)model.parameters()सिर्फ जनरेटर वस्तु है।

दूसरी तरफ, torch.save(model, filepath)मॉडल ऑब्जेक्ट को स्वयं बचाता है, लेकिन ध्यान रखें कि मॉडल में ऑप्टिमाइज़र नहीं है state_dict। अनुकूलक के राज्य को बचाने के लिए @Jadiel de Armas द्वारा अन्य उत्कृष्ट उत्तर की जाँच करें।


हालांकि यह एक सीधा समाधान नहीं है, समस्या का सार गहराई से विश्लेषण किया गया है! वोट दें।
जेसन यंग

7

एक आम PyTorch सम्मेलन एक .pt या .pth फ़ाइल एक्सटेंशन का उपयोग करके मॉडल को बचाने के लिए है।

सहेजें / लोड संपूर्ण मॉडल सहेजें:

path = "username/directory/lstmmodelgpu.pth"
torch.save(trainer, path)

भार:

मॉडल वर्ग को कहीं न कहीं परिभाषित किया जाना चाहिए

model = torch.load(PATH)
model.eval()

4

यदि आप मॉडल को बचाना चाहते हैं और बाद में प्रशिक्षण फिर से शुरू करना चाहते हैं:

एकल GPU: सहेजें:

state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

भार:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

मल्टीपल जीपीयू: सेव करें

state = {
        'epoch': epoch,
        'state_dict': model.module.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

भार:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

#Don't call DataParallel before loading the model otherwise you will get an error

model = nn.DataParallel(model) #ignore the line if you want to load on Single GPU
हमारी साइट का प्रयोग करके, आप स्वीकार करते हैं कि आपने हमारी Cookie Policy और निजता नीति को पढ़ और समझा लिया है।
Licensed under cc by-sa 3.0 with attribution required.