GELU सक्रियण क्या है?


18

मैं BERT पेपर से गुज़र रहा था जो GELU (गौसियन एरर लाइनर यूनिट) का उपयोग करता है जो GELU ( रूप में समीकरण बताता है जो बदले में

GELU(x)=xP(Xx)=xΦ(x).
0.5x(1+tanh[2/π(x+0.044715x3)])

क्या आप समीकरण को सरल बना सकते हैं और बता सकते हैं कि यह कैसे अनुमानित किया गया है।

जवाबों:


19

GELU फ़ंक्शन

हम निम्न प्रकार सेN(0,1) , यानी के संचयी वितरण का विस्तार कर सकते हैं: Φ(x)

GELU(x):=xP(Xx)=xΦ(x)=0.5x(1+erf(x2))

ध्यान दें कि यह एक परिभाषा है , न कि एक समीकरण (या एक संबंध)। लेखकों ने इस प्रस्ताव के लिए कुछ औचित्य प्रदान किए हैं, उदाहरण के लिए एक स्टोकेस्टिक उपमा , हालांकि गणितीय रूप से, यह सिर्फ एक परिभाषा है।

यहाँ GELU की साजिश है:

तन्ह सन्निकटन

इस प्रकार के संख्यात्मक अंदाजों के लिए, मुख्य विचार एक समान फ़ंक्शन (मुख्य रूप से अनुभव पर आधारित) को खोजने, इसे मानकीकृत करने और फिर इसे मूल फ़ंक्शन से बिंदुओं के एक सेट पर फिट करने का है।

यह जानते हुए कि erf(x)tanh ( x ) के बहुत करीब हैtanh(x)

और एरफ का पहला व्युत्पन्न ( एक्सerf(x2)के उस के साथ मेल खाताtanh(2πx)मेंx=0है, जो2π , हम फिट करने के लिए आगे बढ़ना

tanh(2π(x+ax2+bx3+cx4+dx5))
(या अधिक शब्द) अंक का एक सेट के साथ(xi,erf(xi2))

मैंने इस फ़ंक्शन को (1.5,1.5) ( इस साइट का उपयोग करके ) के बीच 20 नमूनों में फिट किया है , और यहां गुणांक हैं:

की स्थापना करके a=c=d=0 , b होने का अनुमान था 0.04495641 । एक विस्तृत श्रृंखला से अधिक नमूनों के साथ (वह साइट केवल 20 की अनुमति है), गुणांक b कागज के 0.044715 करीब होगा । अंत में हम प्राप्त करते हैं

GELU(x)=xΦ(x)=0.5x(1+erf(x2))0.5x(1+tanh(2π(x+0.044715x3)))

साथ मतलब वर्ग त्रुटि 108 के लिए x[10,10]

ध्यान दें कि यदि हमने पहले व्युत्पन्न शब्द, √ के बीच संबंध का उपयोग नहीं किया है2π

0.5x(1+tanh(0.797885x+0.035677x3))
रूप में मापदंडों में शामिल किया गया होगा जो कम सुंदर (कम विश्लेषणात्मक, अधिक संख्यात्मक) है!

समता का उपयोग

जैसा कि @BookYourLuck ने सुझाव दिया है , हम बहुपद के स्थान को सीमित करने के लिए कार्यों की समता का उपयोग कर सकते हैं जिसमें हम खोज करते हैं। अर्थात्, चूँकि erf एक विषम कार्य है, अर्थात f(x)=f(x) , और tanh भी एक विषम कार्य है, tanh के अंदर बहुपद फलन pol(x) भी विषम होना चाहिए (जिसमें केवल विषम शक्तियां होनी चाहिए) एक्स ) के लिए ERF ( - x ) tanh ( पोल ( - xtanhx

erf(x)tanh(pol(x))=tanh(pol(x))=tanh(pol(x))erf(x)

पहले, हम x2 और x4 शक्तियों के लिए (लगभग) शून्य गुणांक के साथ समाप्त होने के लिए भाग्यशाली थे , हालांकि सामान्य तौर पर, यह निम्न गुणवत्ता के अनुमानों को जन्म दे सकता है, उदाहरण के लिए, 0.23x2 तरह एक शब्द है जिसे रद्द किया जा रहा है अतिरिक्त शर्तों (सम या विषम) के बजाय केवल 0x2 के लिए चुनने से ।

सिग्मोइड सन्निकटन

erf(x)2(σ(x)12)104x[10,10]

डेटा बिंदुओं को उत्पन्न करने, फ़ंक्शंस की फिटिंग करने और माध्य चुकता त्रुटियों की गणना करने के लिए यहाँ एक पायथन कोड है:

import math
import numpy as np
import scipy.optimize as optimize


def tahn(xs, a):
    return [math.tanh(math.sqrt(2 / math.pi) * (x + a * x**3)) for x in xs]


def sigmoid(xs, a):
    return [2 * (1 / (1 + math.exp(-a * x)) - 0.5) for x in xs]


print_points = 0
np.random.seed(123)
# xs = [-2, -1, -.9, -.7, 0.6, -.5, -.4, -.3, -0.2, -.1, 0,
#       .1, 0.2, .3, .4, .5, 0.6, .7, .9, 2]
# xs = np.concatenate((np.arange(-1, 1, 0.2), np.arange(-4, 4, 0.8)))
# xs = np.concatenate((np.arange(-2, 2, 0.5), np.arange(-8, 8, 1.6)))
xs = np.arange(-10, 10, 0.001)
erfs = np.array([math.erf(x/math.sqrt(2)) for x in xs])
ys = np.array([0.5 * x * (1 + math.erf(x/math.sqrt(2))) for x in xs])

# Fit tanh and sigmoid curves to erf points
tanh_popt, _ = optimize.curve_fit(tahn, xs, erfs)
print('Tanh fit: a=%5.5f' % tuple(tanh_popt))

sig_popt, _ = optimize.curve_fit(sigmoid, xs, erfs)
print('Sigmoid fit: a=%5.5f' % tuple(sig_popt))

# curves used in https://mycurvefit.com:
# 1. sinh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))/cosh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))
# 2. sinh(sqrt(2/3.141593)*(x+b*x^3))/cosh(sqrt(2/3.141593)*(x+b*x^3))
y_paper_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + 0.044715 * x**3))) for x in xs])
tanh_error_paper = (np.square(ys - y_paper_tanh)).mean()
y_alt_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + tanh_popt[0] * x**3))) for x in xs])
tanh_error_alt = (np.square(ys - y_alt_tanh)).mean()

# curve used in https://mycurvefit.com:
# 1. 2*(1/(1+2.718281828459^(-(a*x))) - 0.5)
y_paper_sigmoid = np.array([x * (1 / (1 + math.exp(-1.702 * x))) for x in xs])
sigmoid_error_paper = (np.square(ys - y_paper_sigmoid)).mean()
y_alt_sigmoid = np.array([x * (1 / (1 + math.exp(-sig_popt[0] * x))) for x in xs])
sigmoid_error_alt = (np.square(ys - y_alt_sigmoid)).mean()

print('Paper tanh error:', tanh_error_paper)
print('Alternative tanh error:', tanh_error_alt)
print('Paper sigmoid error:', sigmoid_error_paper)
print('Alternative sigmoid error:', sigmoid_error_alt)

if print_points == 1:
    print(len(xs))
    for x, erf in zip(xs, erfs):
        print(x, erf)

आउटपुट:

Tanh fit: a=0.04485
Sigmoid fit: a=1.70099
Paper tanh error: 2.4329173471294176e-08
Alternative tanh error: 2.698034519269613e-08
Paper sigmoid error: 5.6479106346814546e-05
Alternative sigmoid error: 5.704246564663601e-05

2
सन्निकटन की आवश्यकता क्यों है? क्या वे सिर्फ erf फ़ंक्शन का उपयोग नहीं कर सकते थे?
सेबीसीबी

8

Φ(x)=12erfc(x2)=12(1+erf(x2))
erf
erf(x2)tanh(2π(x+ax3))
0.044715

एक्स[-1,1]एक्स

tanh(एक्स)=एक्स-एक्स33+(एक्स3)
आर(एक्स)=2π(एक्स-एक्स33)+(एक्स3)
tanh(2π(एक्स+एक्स3))=2π(एक्स+(-23π)एक्स3)+(एक्स3)
आर(एक्स2)=2π(एक्स-एक्स36)+(एक्स3)
एक्स3
a0.04553992412
0.044715

हमारी साइट का प्रयोग करके, आप स्वीकार करते हैं कि आपने हमारी Cookie Policy और निजता नीति को पढ़ और समझा लिया है।
Licensed under cc by-sa 3.0 with attribution required.