import torch
from torch.nn import init

from apex._autocast_utils import _cast_if_autocast_enabled

import fast_layer_norm  # compiled CUDA extension shipped with apex.contrib


class FastLayerNormFN(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, gamma, beta, epsilon, memory_efficient):
        ctx.x_shape = x.shape
        ctx.memory_efficient = memory_efficient

        x = x.contiguous()
        gamma = gamma.contiguous()
        beta = beta.contiguous()
        hidden_size = gamma.numel()
        # Flatten all leading dimensions; the kernel normalizes over the last one.
        xmat = x.view((-1, hidden_size))
        ymat, mu, rsigma = fast_layer_norm.ln_fwd(xmat, gamma, beta, epsilon)
        if ctx.memory_efficient:
            # Memory-efficient mode saves the output instead of the input;
            # the backward kernel reconstructs what it needs from
            # y, gamma, beta, and rsigma.
            ctx.save_for_backward(ymat, gamma, None, rsigma, beta)
        else:
            ctx.save_for_backward(xmat, gamma, mu, rsigma, None)
        return ymat.view(x.shape)

    @staticmethod
    def backward(ctx, dy):
        # dy can arrive non-contiguous, so copy rather than assert contiguity.
        dy = dy.contiguous()
        x_or_y_mat, gamma, mu, rsigma, beta = ctx.saved_tensors
        dymat = dy.view(x_or_y_mat.shape)
        dxmat, dgamma, dbeta, _, _ = fast_layer_norm.ln_bwd(
            dymat, x_or_y_mat, mu, rsigma, gamma, beta, ctx.memory_efficient
        )
        dx = dxmat.view(ctx.x_shape)
        return dx, dgamma, dbeta, None, None
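
# A minimal reference sketch of what ln_fwd computes, in plain PyTorch.
# Illustration only: the name _reference_layer_norm is not part of apex,
# and the CUDA kernel fuses these steps (and also returns mu and rsigma).
def _reference_layer_norm(xmat, gamma, beta, epsilon):
    mu = xmat.mean(dim=-1, keepdim=True)
    rsigma = torch.rsqrt(xmat.var(dim=-1, unbiased=False, keepdim=True) + epsilon)
    return (xmat - mu) * rsigma * gamma + beta
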
def _fast_layer_norm(x, weight, bias, epsilon, memory_efficient):
    # Cast the inputs to the autocast dtype up front, then run the custom
    # Function with autocast disabled so it sees uniform dtypes.
    args = _cast_if_autocast_enabled(x, weight, bias, epsilon, memory_efficient)
    with torch.cuda.amp.autocast(enabled=False):
        return FastLayerNormFN.apply(*args)

class FastLayerNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-5, memory_efficient=False):
        super().__init__()
        self.epsilon = eps
        self.memory_efficient = memory_efficient
        self.weight = torch.nn.Parameter(torch.empty(hidden_size))
        self.bias = torch.nn.Parameter(torch.empty(hidden_size))
        self.reset_parameters()

    def reset_parameters(self):
        # Identity initialization: gamma = 1, beta = 0.
        init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, x):
        return _fast_layer_norm(x, self.weight, self.bias, self.epsilon, self.memory_efficient)
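
# Minimal usage sketch, assuming a CUDA device and an apex build that
# includes the fast_layer_norm extension; the kernels only support specific
# hidden sizes (1024 among them -- see apex.contrib.layer_norm for the list).
if __name__ == "__main__":
    ln = FastLayerNorm(1024, memory_efficient=True).cuda().half()
    x = torch.randn(8, 512, 1024, device="cuda", dtype=torch.float16, requires_grad=True)
    y = ln(x)           # FastLayerNormFN.forward via the fused CUDA kernel
    y.sum().backward()  # exercises ln_bwd
    print(y.shape, x.grad.shape)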