हार के मूल्य के आधार पर केरस को प्रशिक्षण कैसे रोकें?


82

वर्तमान में मैं निम्नलिखित कोड का उपयोग करता हूं:

callbacks = [
    EarlyStopping(monitor='val_loss', patience=2, verbose=0),
    ModelCheckpoint(kfold_weights_path, monitor='val_loss', save_best_only=True, verbose=0),
]
model.fit(X_train.astype('float32'), Y_train, batch_size=batch_size, nb_epoch=nb_epoch,
      shuffle=True, verbose=1, validation_data=(X_valid, Y_valid),
      callbacks=callbacks)

यह बताता है कि जब प्रशिक्षण में 2 युगों तक सुधार नहीं हुआ तो केर को रोकना पड़ा। लेकिन मैं कुछ निरंतर "टीएचआर" से कम होने के बाद प्रशिक्षण रोकना चाहता हूं:

if val_loss < THR:
    break

मैंने दस्तावेज़ीकरण में देखा है कि आपके स्वयं के कॉलबैक करने की संभावना है: http://keras.io/callbacks/ लेकिन प्रशिक्षण प्रक्रिया को रोकने के लिए कुछ भी नहीं मिला। मुझे एक सलाह की जरूरत है।

जवाबों:


85

मुझे जवाब मिल गया। मैंने केरस स्रोतों में देखा और अर्लीस्टॉपिंग के लिए कोड का पता लगाया। मैंने अपना कॉलबैक बनाया, उसके आधार पर:

class EarlyStoppingByLossVal(Callback):
    def __init__(self, monitor='val_loss', value=0.00001, verbose=0):
        super(Callback, self).__init__()
        self.monitor = monitor
        self.value = value
        self.verbose = verbose

    def on_epoch_end(self, epoch, logs={}):
        current = logs.get(self.monitor)
        if current is None:
            warnings.warn("Early stopping requires %s available!" % self.monitor, RuntimeWarning)

        if current < self.value:
            if self.verbose > 0:
                print("Epoch %05d: early stopping THR" % epoch)
            self.model.stop_training = True

और उपयोग:

callbacks = [
    EarlyStoppingByLossVal(monitor='val_loss', value=0.00001, verbose=1),
    # EarlyStopping(monitor='val_loss', patience=2, verbose=0),
    ModelCheckpoint(kfold_weights_path, monitor='val_loss', save_best_only=True, verbose=0),
]
model.fit(X_train.astype('float32'), Y_train, batch_size=batch_size, nb_epoch=nb_epoch,
      shuffle=True, verbose=1, validation_data=(X_valid, Y_valid),
      callbacks=callbacks)

1
बस अगर यह किसी के लिए उपयोगी होगा - मेरे मामले में मैंने मॉनिटर = 'नुकसान' का इस्तेमाल किया, तो यह अच्छी तरह से काम किया।
QtRoS

15
ऐसा लगता है कि केर को अपडेट कर दिया गया है। EarlyStopping कॉलबैक फ़ंक्शन इसे में अब बनाया min_delta है। अब स्रोत कोड हैक करने की कोई जरूरत नहीं है, याय! stackoverflow.com/a/41459368/3345375
jkdev

3
प्रश्न और उत्तरों को फिर से पढ़ने पर, मुझे अपने आप को सही करने की आवश्यकता है: min_delta का अर्थ है "यदि प्रति युग (या कई युगों के अनुसार) पर्याप्त सुधार नहीं हुआ है तो जल्दी रुकें।" हालांकि, ओपी ने पूछा कि "नुकसान को एक निश्चित स्तर से कम होने पर जल्दी कैसे रोकें।"
jkdev

NameError: 'Callback' नाम परिभाषित नहीं है ... मैं इसे कैसे ठीक करूँगा?
एलिसाएलियाह

2
एलियाह यह कोशिश करते हैं: from keras.callbacks import Callback
ZFTurbo

26

Keras.callbacks.EarlyStopping कॉलबैक में एक min_delta तर्क है। करेस प्रलेखन से:

min_delta: सुधार के रूप में अर्हता प्राप्त करने के लिए निगरानी की गई मात्रा में न्यूनतम परिवर्तन, अर्थात min_delta से कम का पूर्ण परिवर्तन, कोई सुधार नहीं होगा।


3
संदर्भ के लिए, यहाँ Keras (1.1.0) के पुराने संस्करण के लिए डॉक्स हैं, जिसमें min_delta तर्क अभी तक शामिल नहीं था: faroit.github.io/keras-docs/1.1.0/callback/#earlystopping
jkdev

मैं इसे तब तक कैसे नहीं रोक सकता जब तक यह min_deltaकई युगों तक बना रहता है?
zyxue

अर्लीस्टॉपिंग नामक धैर्य के लिए एक और पैरामीटर है: कोई सुधार के साथ युगों की संख्या जिसके बाद प्रशिक्षण रोक दिया जाएगा।
डेविन

13

एक समाधान के लिए model.fit(nb_epoch=1, ...)एक लूप के अंदर कॉल करना है, फिर आप लूप के लिए एक ब्रेक स्टेटमेंट डाल सकते हैं और जो चाहें अन्य कस्टम नियंत्रण प्रवाह कर सकते हैं।


यह अच्छा होगा यदि वे एक कॉलबैक करते हैं जो एक एकल फ़ंक्शन में लेता है जो ऐसा कर सकता है।
ईमानदारी ६

8

मैंने कस्टम कॉलबैक का उपयोग करके उसी समस्या को हल किया।

निम्नलिखित कस्टम कॉलबैक कोड में THR को उस मूल्य के साथ असाइन करें जिस पर आप प्रशिक्षण रोकना चाहते हैं और अपने मॉडल में कॉलबैक जोड़ें।

from keras.callbacks import Callback

class stopAtLossValue(Callback):

        def on_batch_end(self, batch, logs={}):
            THR = 0.03 #Assign THR with the value at which you want to stop training.
            if logs.get('loss') <= THR:
                 self.model.stop_training = True

2

जब मैं अभ्यास विशेषज्ञता में TensorFlow ले रहा था , मैंने एक बहुत ही सुंदर तकनीक सीखी। स्वीकृत उत्तर से थोड़ा संशोधित।

आइए हमारे पसंदीदा MNIST डेटा के साथ उदाहरण सेट करें।

import tensorflow as tf

class new_callback(tf.keras.callbacks.Callback):
    def epoch_end(self, epoch, logs={}): 
        if(logs.get('accuracy')> 0.90): # select the accuracy
            print("\n !!! 90% accuracy, no further training !!!")
            self.model.stop_training = True

mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0 #normalize

callbacks = new_callback()

# model = tf.keras.models.Sequential([# define your model here])

model.compile(optimizer=tf.optimizers.Adam(),
          loss='sparse_categorical_crossentropy',
          metrics=['accuracy'])
model.fit(x_train, y_train, epochs=10, callbacks=[callbacks])

इसलिए, यहां मैंने सेट किया metrics=['accuracy'], और इस तरह कॉलबैक क्लास में स्थिति सेट हो गई'accuracy'> 0.90

आप किसी भी मीट्रिक को चुन सकते हैं और इस उदाहरण की तरह प्रशिक्षण की निगरानी कर सकते हैं। सबसे महत्वपूर्ण बात यह है कि आप अलग-अलग मीट्रिक के लिए अलग-अलग स्थिति निर्धारित कर सकते हैं और उन्हें एक साथ उपयोग कर सकते हैं।

उम्मीद है कि यह मदद करता है!


समारोह का नाम on_epoch_end
xarion

0

मेरे लिए मॉडल केवल प्रशिक्षण रोक देगा यदि मैंने स्टॉप_ट्रेनिंग पैरामीटर को True पर सेट करने के बाद रिटर्न स्टेटमेंट जोड़ा क्योंकि मैं self.model.ev मूल्यांकन के बाद कॉल कर रहा था। इसलिए या तो stop_training डालना सुनिश्चित करें = फ़ंक्शन के अंत में सही है या रिटर्न स्टेटमेंट जोड़ें।

def on_epoch_end(self, batch, logs):
        self.epoch += 1
        self.stoppingCounter += 1
        print('\nstopping counter \n',self.stoppingCounter)

        #Stop training if there hasn't been any improvement in 'Patience' epochs
        if self.stoppingCounter >= self.patience:
            self.model.stop_training = True
            return

        # Test on additional set if there is one
        if self.testingOnAdditionalSet:
            evaluation = self.model.evaluate(self.val2X, self.val2Y, verbose=0)
            self.validationLoss2.append(evaluation[0])
            self.validationAcc2.append(evaluation[1])enter code here

0

यदि आप एक कस्टम ट्रेनिंग लूप का उपयोग कर रहे हैं, तो आप collections.dequeएक "रोलिंग" सूची का उपयोग कर सकते हैं , जिसे जोड़ा जा सकता है, और जब सूची लंबी होती है, तो बाएं हाथ की चीजें पॉप आउट हो जाती हैं maxlen। यहाँ लाइन है:

loss_history = deque(maxlen=early_stopping + 1)

for epoch in range(epochs):
    fit(epoch)
    loss_history.append(test_loss.result().numpy())
    if len(loss_history) > early_stopping and loss_history.popleft() < min(loss_history)
            break

यहाँ एक पूर्ण उदाहरण है:

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow_datasets as tfds
import tensorflow as tf
from tensorflow.keras.layers import Dense
from collections import deque

data, info = tfds.load('iris', split='train', as_supervised=True, with_info=True)

data = data.map(lambda x, y: (tf.cast(x, tf.int32), y))

train_dataset = data.take(120).batch(4)
test_dataset = data.skip(120).take(30).batch(4)

model = tf.keras.models.Sequential([
    Dense(8, activation='relu'),
    Dense(16, activation='relu'),
    Dense(info.features['label'].num_classes)])

loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

train_loss = tf.keras.metrics.Mean()
test_loss = tf.keras.metrics.Mean()

train_acc = tf.keras.metrics.SparseCategoricalAccuracy()
test_acc = tf.keras.metrics.SparseCategoricalAccuracy()

opt = tf.keras.optimizers.Adam(learning_rate=1e-3)


@tf.function
def train_step(inputs, labels):
    with tf.GradientTape() as tape:
        logits = model(inputs, training=True)
        loss = loss_object(labels, logits)

    gradients = tape.gradient(loss, model.trainable_variables)
    opt.apply_gradients(zip(gradients, model.trainable_variables))
    train_loss(loss)
    train_acc(labels, logits)


@tf.function
def test_step(inputs, labels):
    logits = model(inputs, training=False)
    loss = loss_object(labels, logits)
    test_loss(loss)
    test_acc(labels, logits)


def fit(epoch):
    template = 'Epoch {:>2} Train Loss {:.3f} Test Loss {:.3f} ' \
               'Train Acc {:.2f} Test Acc {:.2f}'

    train_loss.reset_states()
    test_loss.reset_states()
    train_acc.reset_states()
    test_acc.reset_states()

    for X_train, y_train in train_dataset:
        train_step(X_train, y_train)

    for X_test, y_test in test_dataset:
        test_step(X_test, y_test)

    print(template.format(
        epoch + 1,
        train_loss.result(),
        test_loss.result(),
        train_acc.result(),
        test_acc.result()
    ))


def main(epochs=50, early_stopping=10):
    loss_history = deque(maxlen=early_stopping + 1)

    for epoch in range(epochs):
        fit(epoch)
        loss_history.append(test_loss.result().numpy())
        if len(loss_history) > early_stopping and loss_history.popleft() < min(loss_history):
            print(f'\nEarly stopping. No validation loss '
                  f'improvement in {early_stopping} epochs.')
            break

if __name__ == '__main__':
    main(epochs=250, early_stopping=10)
Epoch  1 Train Loss 1.730 Test Loss 1.449 Train Acc 0.33 Test Acc 0.33
Epoch  2 Train Loss 1.405 Test Loss 1.220 Train Acc 0.33 Test Acc 0.33
Epoch  3 Train Loss 1.173 Test Loss 1.054 Train Acc 0.33 Test Acc 0.33
Epoch  4 Train Loss 1.006 Test Loss 0.935 Train Acc 0.33 Test Acc 0.33
Epoch  5 Train Loss 0.885 Test Loss 0.846 Train Acc 0.33 Test Acc 0.33
...
Epoch 89 Train Loss 0.196 Test Loss 0.240 Train Acc 0.89 Test Acc 0.87
Epoch 90 Train Loss 0.195 Test Loss 0.239 Train Acc 0.89 Test Acc 0.87
Epoch 91 Train Loss 0.195 Test Loss 0.239 Train Acc 0.89 Test Acc 0.87
Epoch 92 Train Loss 0.194 Test Loss 0.239 Train Acc 0.90 Test Acc 0.87

Early stopping. No validation loss improvement in 10 epochs.
हमारी साइट का प्रयोग करके, आप स्वीकार करते हैं कि आपने हमारी Cookie Policy और निजता नीति को पढ़ और समझा लिया है।
Licensed under cc by-sa 3.0 with attribution required.