मैं BERT पेपर से गुज़र रहा था जो GELU (गौसियन एरर लाइनर यूनिट) का उपयोग करता है जो GELU ( रूप में समीकरण बताता है
जो बदले में
क्या आप समीकरण को सरल बना सकते हैं और बता सकते हैं कि यह कैसे अनुमानित किया गया है।
मैं BERT पेपर से गुज़र रहा था जो GELU (गौसियन एरर लाइनर यूनिट) का उपयोग करता है जो GELU ( रूप में समीकरण बताता है
जो बदले में
क्या आप समीकरण को सरल बना सकते हैं और बता सकते हैं कि यह कैसे अनुमानित किया गया है।
जवाबों:
हम निम्न प्रकार से , यानी के संचयी वितरण का विस्तार कर सकते हैं:
ध्यान दें कि यह एक परिभाषा है , न कि एक समीकरण (या एक संबंध)। लेखकों ने इस प्रस्ताव के लिए कुछ औचित्य प्रदान किए हैं, उदाहरण के लिए एक स्टोकेस्टिक उपमा , हालांकि गणितीय रूप से, यह सिर्फ एक परिभाषा है।
यहाँ GELU की साजिश है:
इस प्रकार के संख्यात्मक अंदाजों के लिए, मुख्य विचार एक समान फ़ंक्शन (मुख्य रूप से अनुभव पर आधारित) को खोजने, इसे मानकीकृत करने और फिर इसे मूल फ़ंक्शन से बिंदुओं के एक सेट पर फिट करने का है।
यह जानते हुए कि tanh ( x ) के बहुत करीब है
और एरफ का पहला व्युत्पन्न ( एक्सके उस के साथ मेल खातामेंहै, जो , हम फिट करने के लिए आगे बढ़ना
मैंने इस फ़ंक्शन को ( इस साइट का उपयोग करके ) के बीच 20 नमूनों में फिट किया है , और यहां गुणांक हैं:
की स्थापना करके , होने का अनुमान था । एक विस्तृत श्रृंखला से अधिक नमूनों के साथ (वह साइट केवल 20 की अनुमति है), गुणांक कागज के करीब होगा । अंत में हम प्राप्त करते हैं
साथ मतलब वर्ग त्रुटि के लिए ।
ध्यान दें कि यदि हमने पहले व्युत्पन्न शब्द, √ के बीच संबंध का उपयोग नहीं किया है
जैसा कि @BookYourLuck ने सुझाव दिया है , हम बहुपद के स्थान को सीमित करने के लिए कार्यों की समता का उपयोग कर सकते हैं जिसमें हम खोज करते हैं। अर्थात्, चूँकि एक विषम कार्य है, अर्थात , और भी एक विषम कार्य है, tanh के अंदर बहुपद फलन भी विषम होना चाहिए (जिसमें केवल विषम शक्तियां होनी चाहिए) एक्स ) के लिए
ERF ( - x ) ≃ tanh ( पोल ( - x
पहले, हम और शक्तियों के लिए (लगभग) शून्य गुणांक के साथ समाप्त होने के लिए भाग्यशाली थे , हालांकि सामान्य तौर पर, यह निम्न गुणवत्ता के अनुमानों को जन्म दे सकता है, उदाहरण के लिए, तरह एक शब्द है जिसे रद्द किया जा रहा है अतिरिक्त शर्तों (सम या विषम) के बजाय केवल के लिए चुनने से ।
डेटा बिंदुओं को उत्पन्न करने, फ़ंक्शंस की फिटिंग करने और माध्य चुकता त्रुटियों की गणना करने के लिए यहाँ एक पायथन कोड है:
import math
import numpy as np
import scipy.optimize as optimize
def tahn(xs, a):
return [math.tanh(math.sqrt(2 / math.pi) * (x + a * x**3)) for x in xs]
def sigmoid(xs, a):
return [2 * (1 / (1 + math.exp(-a * x)) - 0.5) for x in xs]
print_points = 0
np.random.seed(123)
# xs = [-2, -1, -.9, -.7, 0.6, -.5, -.4, -.3, -0.2, -.1, 0,
# .1, 0.2, .3, .4, .5, 0.6, .7, .9, 2]
# xs = np.concatenate((np.arange(-1, 1, 0.2), np.arange(-4, 4, 0.8)))
# xs = np.concatenate((np.arange(-2, 2, 0.5), np.arange(-8, 8, 1.6)))
xs = np.arange(-10, 10, 0.001)
erfs = np.array([math.erf(x/math.sqrt(2)) for x in xs])
ys = np.array([0.5 * x * (1 + math.erf(x/math.sqrt(2))) for x in xs])
# Fit tanh and sigmoid curves to erf points
tanh_popt, _ = optimize.curve_fit(tahn, xs, erfs)
print('Tanh fit: a=%5.5f' % tuple(tanh_popt))
sig_popt, _ = optimize.curve_fit(sigmoid, xs, erfs)
print('Sigmoid fit: a=%5.5f' % tuple(sig_popt))
# curves used in https://mycurvefit.com:
# 1. sinh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))/cosh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))
# 2. sinh(sqrt(2/3.141593)*(x+b*x^3))/cosh(sqrt(2/3.141593)*(x+b*x^3))
y_paper_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + 0.044715 * x**3))) for x in xs])
tanh_error_paper = (np.square(ys - y_paper_tanh)).mean()
y_alt_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + tanh_popt[0] * x**3))) for x in xs])
tanh_error_alt = (np.square(ys - y_alt_tanh)).mean()
# curve used in https://mycurvefit.com:
# 1. 2*(1/(1+2.718281828459^(-(a*x))) - 0.5)
y_paper_sigmoid = np.array([x * (1 / (1 + math.exp(-1.702 * x))) for x in xs])
sigmoid_error_paper = (np.square(ys - y_paper_sigmoid)).mean()
y_alt_sigmoid = np.array([x * (1 / (1 + math.exp(-sig_popt[0] * x))) for x in xs])
sigmoid_error_alt = (np.square(ys - y_alt_sigmoid)).mean()
print('Paper tanh error:', tanh_error_paper)
print('Alternative tanh error:', tanh_error_alt)
print('Paper sigmoid error:', sigmoid_error_paper)
print('Alternative sigmoid error:', sigmoid_error_alt)
if print_points == 1:
print(len(xs))
for x, erf in zip(xs, erfs):
print(x, erf)
आउटपुट:
Tanh fit: a=0.04485
Sigmoid fit: a=1.70099
Paper tanh error: 2.4329173471294176e-08
Alternative tanh error: 2.698034519269613e-08
Paper sigmoid error: 5.6479106346814546e-05
Alternative sigmoid error: 5.704246564663601e-05