tf.truncated_normal和tf.random_normal有什么区别?

问题描述 投票:42回答:3

tf.random_normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None, name=None)从正态分布输出随机值。

tf.truncated_normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None, name=None)从截断的正态分布输出随机值。

我试着谷歌搜索'截断正态分布'。但是不太了解。

math machine-learning tensorflow
3个回答
62
投票

documentation说明了一切:对于截断的正态分布:

生成的值遵循具有指定平均值和标准偏差的正态分布,除了丢弃并重新选择幅度大于平均值2个标准偏差的值。

最有可能通过为自己绘制图形很容易理解差异(%magic是因为我使用了jupyter笔记本):

import tensorflow as tf
import matplotlib.pyplot as plt

%matplotlib inline  

n = 500000
A = tf.truncated_normal((n,))
B = tf.random_normal((n,))
with tf.Session() as sess:
    a, b = sess.run([A, B])

现在

plt.hist(a, 100, (-4.2, 4.2));
plt.hist(b, 100, (-4.2, 4.2));

enter image description here


使用截断法线的要点是克服像sigmoid这样的tome函数的饱和度(如果值太大/太小,则神经元停止学习)。


22
投票

tf.truncated_normal()从正态分布中选择随机数,该正态分布的均值接近于0且值接近于0.例如,从-0.1到0.1。它被称为截断,因为你从正常分布切断尾巴。

tf.random_normal()从正态分布中选择随机数,其均值接近于0,但值可以稍微分开。例如,从-2到2。

在机器学习中,实际上,您通常希望权重接近0。


8
投票

API documentation for tf.truncated_normal()将功能描述为:

从截断的正态分布输出随机值。

生成的值遵循具有指定平均值和标准偏差的正态分布,除了丢弃并重新选择幅度大于平均值2个标准偏差的值。

© www.soinside.com 2019 - 2024. All rights reserved.