39 Star 235 Fork 71

OpenDocCN / pytorch-doc-zh

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
20.md 6.04 KB
一键复制 编辑 原始数据 按行查看 历史
布客飞龙 提交于 2021-10-14 22:51 . 2021-10-14 22:51:06

torch.nn.init

torch.nn.init.calculate_gain(nonlinearity,param=None) 

返回给定非线性函数的推荐增益值。值如下:

非线性 获得
linear 1
conv{1,2,3}d 1
sigmoid 1
tanh 5/3
relu sqrt(2)
leaky_relu sqrt(2/(1+negative_slope^2))

参数:

  1. nonlinearity - 非线性函数(nn.functional 名称)
  2. param - 非线性函数的可选参数

例子:

gain = nn.init.gain('leaky_relu') 
torch.nn.init.uniform(tensor, a=0, b=1)[source] 

从均匀分布 U(a, b)中生成值,填充输入的张量或变量

参数:

  1. tensor - n 维的 torch.Tensor 或者 autograd.Variable
  2. a - 均匀分布的下限
  3. b - 均匀分布的上限

例子:

w = torch.Tensor(3, 5)
print nn.init.uniform(w)
# 输出: 
# 0.0470  0.9742  0.9736  0.7976  0.1219
# 0.9390  0.7575  0.9370  0.4786  0.8396
# 0.1849  0.5384  0.0625  0.3719  0.1739
# [torch.FloatTensor of size 3x5] 
torch.nn.init.normal(tensor, mean=0, std=1) 

从给定均值和标准差的正态分布 N(mean, std)中生成值,填充输入的张量或变量

参数:

  1. tensor – n 维的 torch.Tensor 或者 autograd.Variable
  2. mean – 正态分布的平均值
  3. std – 正态分布的标准偏差

例子:

w = torch.Tensor(3, 5)
print torch.nn.init.normal(w) 
torch.nn.init.constant(tensor, val) 

使用值 val 填充输入 Tensor 或 Variable 。

参数:

  1. tensor – n 维的 torch.Tensor 或 autograd.Variable
  2. val – 填充张量的值

例子:

w = torch.Tensor(3, 5)
print torch.nn.init.constant(w) 
torch.nn.init.eye(tensor) 

用单位矩阵来填充 2 维输入张量或变量。在线性层尽可能多的保存输入特性。

参数:

  1. tensor – 2 维的 torch.Tensor 或 autograd.Variable

例子:

w = torch.Tensor(3, 5)
print torch.nn.init.eye(w) 
torch.nn.init.dirac(tensor) 

Dirac delta函数来填充{3, 4, 5}维输入张量或变量。在卷积层尽可能多的保存输入通道特性。

参数:

  1. tensor – {3, 4, 5}维的 torch.Tensor 或 autograd.Variable

例子:

w = torch.Tensor(3, 16, 5, 5)
print torch.nn.init.dirac(w) 
torch.nn.init.xavier_uniform(tensor, gain=1) 

根据 Glorot, X.和 Bengio, Y.在"理解难度训练深前馈神经网络"中描述的方法,使用均匀分布,填充张量或变量。结果张量中的值采样自 U(-a, a),其中 a= gain _sqrt( 2/(fan_in + fan_out))_sqrt(3). 该方法也被称为glorot的初始化。

参数:

  1. tensor – n 维的 torch.Tensor 或 autograd.Variable
  2. gain - 可选的缩放因子

例子:

w = torch.Tensor(3, 5)
print torch.nn.init.xavier_uniform(w, gain=nn.init.calculate_gain('relu')) 
torch.nn.init.xavier_normal(tensor, gain=1) 

根据 Glorot, X.和 Bengio, Y. 于 2010 年在"理解难度训练深前馈神经网络"中描述的方法,采用正态分布,填充张量或变量。结果张量中的值采样自均值为 0,标准差为 gain * sqrt(2/(fan_in + fan_out))的正态分布。该方法也被称为glorot的初始化。

参数:

  1. tensor – n 维的 torch.Tensor 或 autograd.Variable
  2. gain - 可选的缩放因子

例子:

>>> w = torch.Tensor(3, 5)
>>> nn.init.xavier_normal(w) 
torch.nn.init.kaiming_uniform(tensor, a=0, mode='fan_in') 

根据 He, K 等人于 2015 年在"深入研究了超越人类水平的性能:整流器在 ImageNet 分类"中描述的方法,采用正态分布,填充张量或变量。结果张量中的值采样自 U(-bound, bound),其中 bound = sqrt(2/((1 + a^2) fan_in)) sqrt(3)。该方法也被称为He的初始化。

参数:

  1. tensor – n 维的 torch.Tensor 或 autograd.Variable
  2. a -此层后使用的整流器的负斜率(默认为 ReLU 为 0)
  3. mode - "fan_in"(默认)或"fan_out"。"fan_in"保留正向传播时权值方差的量级,"fan_out"保留反向传播时的量级。

例子:

w = torch.Tensor(3, 5)
torch.nn.init.kaiming_uniform(w, mode='fan_in') 
torch.nn.init.kaiming_normal(tensor, a=0, mode='fan_in') 

根据 He, K 等人 2015 年在"深入研究了超越人类水平的性能:整流器在 ImageNet 分类"中描述的方法,采用正态分布,填充张量或变量。结果张量中的值采样自均值为 0,标准差为 sqrt(2/((1 + a^2) * fan_in))的正态分布。该方法也被称为He的初始化。

参数:

  1. tensor – n 维的 torch.Tensor 或 autograd.Variable
  2. a -此层后使用的整流器的负斜率(默认为 ReLU 为 0)
  3. mode - "fan_in"(默认)或"fan_out"。"fan_in"保留正向传播时权值方差的量级,"fan_out"保留反向传播时的量级。
w = torch.Tensor(3, 5)
print torch.nn.init.kaiming_normal(w, mode='fan_out') 
torch.nn.init.orthogonal(tensor, gain=1) 

使用(半)正交矩阵填充输入张量或变量,参考 Saxe,A.等人 2013 年"深深度线性神经网络学习的非线性动力学的精确解"。输入张量必须至少是 2 维的,对于更高维度的张量,超出的维度会被展平。

参数:

  1. tensor – n 维的 torch.Tensor 或 autograd.Variable,其中 n>=2
  2. gain - 可选缩放因子

例子:

w = torch.Tensor(3, 5)
print torch.nn.init.orthogonal(w) 
torch.nn.init.sparse(tensor, sparsity, std=0.01) 

将二维输入张量或变为稀疏矩阵的非零元素,其中非零元素根据一个均值为 0,标准差为 std 的正态分布生成。如"深度学习通过 Hessian 免费优化"- Martens,J.(2010)。

参数:

  1. tensor – n 维的 torch.Tensor 或 autograd.Variable
  2. sparsity - 每列中需要被设置成零的元素比例
  3. std - 用于生成的正态分布的标准差
  4. non-zero values (the) – 例子:
w = torch.Tensor(3, 5)
print torch.nn.init.sparse(w, sparsity=0.1) 

译者署名

用户名 头像 职能 签名
Song 翻译 人生总要追求点什么
1
https://gitee.com/OpenDocCN/pytorch-doc-zh.git
git@gitee.com:OpenDocCN/pytorch-doc-zh.git
OpenDocCN
pytorch-doc-zh
pytorch-doc-zh
master

搜索帮助