# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Transformer model."""
import math
import copy
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
import mindspore.ops.functional as F
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.ops.primitive import constexpr
from .beam_search import BeamSearchDecoder, TileBeam
from .weight_init import normal_weight, weight_variable
class TransformerConfig:
"""
Configuration for `Transformer`.
Args:
batch_size (int): Batch size of input dataset.
seq_length (int): Length of input sequence. Default: 128.
        vocab_size (int): Size of the dictionary of embeddings. Default: 36560.
        hidden_size (int): Size of the encoder/decoder layers. Default: 1024.
num_hidden_layers (int): Number of hidden layers in the Transformer encoder/decoder
cell. Default: 6.
num_attention_heads (int): Number of attention heads in the Transformer
encoder/decoder cell. Default: 16.
intermediate_size (int): Size of intermediate layer in the Transformer
encoder/decoder cell. Default: 4096.
hidden_act (str): Activation function used in the Transformer encoder/decoder
cell. Default: "relu".
hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.3.
attention_probs_dropout_prob (float): The dropout probability for
MultiheadAttention. Default: 0.3.
max_position_embeddings (int): Maximum length of sequences used in this
model. Default: 128.
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        label_smoothing (float): Label smoothing setting. Default: 0.1.
        beam_width (int): Beam width setting. Default: 4.
        max_decode_length (int): Max decode length in evaluation. Default: 80.
        length_penalty_weight (float): Normalize scores of translations according to their length. Default: 1.0.
dtype (:class:`mindspore.dtype`): Data type of the input. Default: mstype.float32.
compute_type (:class:`mindspore.dtype`): Compute type in Transformer. Default: mstype.float32.
"""
def __init__(self,
batch_size,
seq_length=128,
vocab_size=36560,
hidden_size=1024,
num_hidden_layers=6,
num_attention_heads=16,
intermediate_size=4096,
hidden_act="relu",
hidden_dropout_prob=0.3,
attention_probs_dropout_prob=0.3,
max_position_embeddings=128,
initializer_range=0.02,
label_smoothing=0.1,
beam_width=4,
max_decode_length=80,
length_penalty_weight=1.0,
dtype=mstype.float32,
compute_type=mstype.float32):
self.batch_size = batch_size
self.seq_length = seq_length
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.label_smoothing = label_smoothing
self.beam_width = beam_width
self.max_decode_length = max_decode_length
self.length_penalty_weight = length_penalty_weight
self.dtype = dtype
self.compute_type = compute_type
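# Example (illustrative; all values below are placeholders, not prescribed settings):
#
#     config = TransformerConfig(batch_size=32,
#                                seq_length=128,
#                                vocab_size=36560,
#                                hidden_size=1024,
#                                num_hidden_layers=6,
#                                num_attention_heads=16,
#                                compute_type=mstype.float16)
#     model = TransformerModel(config, is_training=True)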
class EmbeddingLookup(nn.Cell):
"""
    An embedding lookup table with a fixed dictionary and size.
Args:
vocab_size (int): Size of the dictionary of embeddings.
embedding_size (int): The size of each embedding vector.
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
"""
def __init__(self,
vocab_size,
embedding_size,
use_one_hot_embeddings=False,
initializer_range=0.02):
super(EmbeddingLookup, self).__init__()
self.vocab_size = vocab_size
self.embedding_size = embedding_size
self.use_one_hot_embeddings = use_one_hot_embeddings
self.embedding_table = Parameter(normal_weight([vocab_size, embedding_size], embedding_size))
self.expand = P.ExpandDims()
self.shape_flat = (-1,)
self.gather = P.Gather()
self.one_hot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.array_mul = P.MatMul()
self.reshape = P.Reshape()
self.shape = P.Shape()
def construct(self, input_ids):
"""Get a embeddings lookup table with a fixed dictionary and size."""
input_shape = self.shape(input_ids)
flat_ids = self.reshape(input_ids, self.shape_flat)
if self.use_one_hot_embeddings:
one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
output_for_reshape = self.array_mul(one_hot_ids, self.embedding_table)
else:
output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
out_shape = input_shape + (self.embedding_size,)
output = self.reshape(output_for_reshape, out_shape)
return output, self.embedding_table
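# Note: construct() returns both the looked-up embeddings and the embedding table
# itself; TransformerModel later feeds this table to PredLogProbs, so the output
# projection shares (ties) its weights with the input embeddings.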
def position_encoding(length,
depth,
min_timescale=1,
max_timescale=1e4):
"""
Create Tensor of sinusoids of different frequencies.
Args:
length (int): Length of the Tensor to create, i.e. Number of steps.
depth (int): Hidden size.
        min_timescale (float): Minimum scale of the sinusoid frequencies. Default: 1.
        max_timescale (float): Maximum scale of the sinusoid frequencies. Default: 1e4.
    Returns:
        numpy.ndarray of shape (length, depth).
"""
depth = depth // 2
positions = np.arange(length, dtype=np.float32)
log_timescale_increment = (np.log(max_timescale / min_timescale) / (depth - 1))
inv_timescales = min_timescale * np.exp(np.arange(depth, dtype=np.float32) * -log_timescale_increment)
scaled_time = np.expand_dims(positions, 1) * np.expand_dims(inv_timescales, 0)
x = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
return x
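# With the default min_timescale=1, the encoding above is the sinusoidal scheme of
# "Attention Is All You Need": for position p and channel i (0 <= i < depth // 2),
#     x[p, i]              = sin(p * inv_timescales[i])
#     x[p, i + depth // 2] = cos(p * inv_timescales[i])
# where inv_timescales[i] decays geometrically from 1 down to 1 / max_timescale.
# Sines fill the first half of the channels and cosines the second half (a channel
# permutation of the interleaved layout in the paper).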
class EmbeddingPostprocessor(nn.Cell):
"""
    Postprocessor that applies positional embeddings to word embeddings.
Args:
embedding_size (int): The size of each embedding vector.
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
max_position_embeddings (int): Maximum length of sequences used in this
model. Default: 128.
dropout_prob (float): The dropout probability. Default: 0.1.
"""
def __init__(self,
embedding_size,
use_one_hot_embeddings=False,
initializer_range=0.02,
max_position_embeddings=128,
dropout_prob=0.1):
super(EmbeddingPostprocessor, self).__init__()
self.scores_mul = Tensor([math.sqrt(float(embedding_size))], dtype=mstype.float32)
self.multiply = P.Mul()
self.add = P.Add()
self.dropout = nn.Dropout(1 - dropout_prob, dtype=mstype.float32)
self.use_dropout = dropout_prob > 0
self.expand_dims = P.ExpandDims()
self.position_embedding_table = Tensor(position_encoding(max_position_embeddings, embedding_size),
mstype.float32)
self.shape = P.Shape()
def construct(self, word_embeddings):
"""Postprocessors apply positional embeddings to word embeddings."""
input_shape = self.shape(word_embeddings)
input_len = input_shape[1]
output = self.multiply(word_embeddings, self.scores_mul)
# add position embeddings
position_embeddings = self.position_embedding_table[0:input_len:1, ::]
position_embeddings = self.expand_dims(position_embeddings, 0)
output = self.add(output, position_embeddings)
if self.use_dropout:
output = self.dropout(output)
return output
class CastWrapper(nn.Cell):
"""
Cast wrapper.
"""
def __init__(self, src_type=mstype.float32, dst_type=mstype.float32):
super(CastWrapper, self).__init__()
self.cast = P.Cast()
self.dst_type = dst_type
def construct(self, x):
return self.cast(x, self.dst_type)
class LayerPreprocess(nn.Cell):
"""
    Preprocess the input of each layer with layer normalization.
"""
def __init__(self,
in_channels=None):
super(LayerPreprocess, self).__init__()
self.layernorm = nn.LayerNorm((in_channels,))
self.cast = P.Cast()
self.get_dtype = P.DType()
def construct(self, input_tensor):
output = self.cast(input_tensor, mstype.float32)
output = self.layernorm(output)
output = self.cast(output, self.get_dtype(input_tensor))
return output
class LayerPostprocess(nn.Cell):
"""
    Postprocess the output of each layer with dropout and a residual connection.
"""
def __init__(self,
dropout_prob=0.1):
super(LayerPostprocess, self).__init__()
self.add = P.Add()
self.dropout = nn.Dropout(1 - dropout_prob)
self.use_dropout = dropout_prob > 0
def construct(self, hidden_tensor, input_tensor):
output = hidden_tensor
if self.use_dropout:
output = self.dropout(output)
output = self.add(output, input_tensor)
return output
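# Together, LayerPreprocess and LayerPostprocess implement the "pre-norm" residual
# pattern used throughout this file: each sublayer normalizes its input first
# (LayerPreprocess), then applies dropout to its output and adds the residual
# (LayerPostprocess); TransformerEncoder/TransformerDecoder apply one final
# LayerPreprocess after the last layer.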
class MultiheadAttention(nn.Cell):
"""
Apply multi-headed attention from "from_tensor" to "to_tensor".
Args:
batch_size (int): Batch size of input datasets.
from_tensor_width (int): Size of last dim of from_tensor.
to_tensor_width (int): Size of last dim of to_tensor.
        out_tensor_width (int): Size of last dim of the output tensor.
        num_attention_heads (int): Number of attention heads. Default: 1.
        size_per_head (int): Size of each attention head. Default: 512.
        query_act (str): Activation function for the query transform. Default: None.
        key_act (str): Activation function for the key transform. Default: None.
        value_act (str): Activation function for the value transform. Default: None.
        out_act (str): Activation function for the output transform. Default: None.
        has_attention_mask (bool): Specifies whether to use attention mask. Default: True.
        attention_probs_dropout_prob (float): The dropout probability for
                                      MultiheadAttention. Default: 0.0.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        do_return_2d_tensor (bool): True for return 2d tensor. False for return 3d
                             tensor. Default: True.
compute_type (:class:`mindspore.dtype`): Compute type in MultiheadAttention. Default: mstype.float32.
"""
def __init__(self,
batch_size,
from_tensor_width,
to_tensor_width,
out_tensor_width,
num_attention_heads=1,
size_per_head=512,
query_act=None,
key_act=None,
value_act=None,
out_act=None,
has_attention_mask=True,
attention_probs_dropout_prob=0.0,
use_one_hot_embeddings=False,
initializer_range=0.02,
do_return_2d_tensor=True,
compute_type=mstype.float32):
super(MultiheadAttention, self).__init__()
self.batch_size = batch_size
self.num_attention_heads = num_attention_heads
self.size_per_head = size_per_head
self.has_attention_mask = has_attention_mask
assert has_attention_mask
self.use_one_hot_embeddings = use_one_hot_embeddings
self.initializer_range = initializer_range
self.do_return_2d_tensor = do_return_2d_tensor
self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=compute_type)
self.reshape = P.Reshape()
self.shape_from_2d = (-1, from_tensor_width)
self.shape_to_2d = (-1, to_tensor_width)
units = num_attention_heads * size_per_head
self.query_layer = nn.Dense(from_tensor_width,
units,
activation=query_act,
has_bias=False,
weight_init=weight_variable([units, from_tensor_width])).to_float(compute_type)
self.key_layer = nn.Dense(to_tensor_width,
units,
activation=key_act,
has_bias=False,
weight_init=weight_variable([units, to_tensor_width])).to_float(compute_type)
self.value_layer = nn.Dense(to_tensor_width,
units,
activation=value_act,
has_bias=False,
weight_init=weight_variable([units, to_tensor_width])).to_float(compute_type)
self.out_layer = nn.Dense(units,
out_tensor_width,
activation=out_act,
has_bias=False,
weight_init=weight_variable([out_tensor_width, units])).to_float(compute_type)
self.matmul_trans_b = P.BatchMatMul(transpose_b=True)
self.multiply = P.Mul()
self.transpose = P.Transpose()
self.trans_shape = (0, 2, 1, 3)
self.trans_shape_relative = (2, 0, 1, 3)
self.trans_shape_position = (1, 2, 0, 3)
self.multiply_data = Tensor([-10000.0,], dtype=compute_type)
self.batch_num = batch_size * num_attention_heads
self.matmul = P.BatchMatMul()
self.softmax = nn.Softmax()
self.dropout = nn.Dropout(1 - attention_probs_dropout_prob)
self.use_dropout = attention_probs_dropout_prob > 0
if self.has_attention_mask:
self.expand_dims = P.ExpandDims()
self.sub = P.Sub()
self.add = P.Add()
self.cast = P.Cast()
self.get_dtype = P.DType()
self.cast_compute_type = CastWrapper(dst_type=compute_type)
self.softmax_cast = P.Cast()
def construct(self, from_tensor, to_tensor, seq_length, enc_seq_length, attention_mask=None):
"""Apply multihead attention."""
from_seq_length = seq_length
to_seq_length = enc_seq_length
shape_from = (self.batch_size, from_seq_length, self.num_attention_heads, self.size_per_head)
shape_to = (self.batch_size, to_seq_length, self.num_attention_heads, self.size_per_head)
if self.do_return_2d_tensor:
shape_return = (self.batch_size * from_seq_length, self.num_attention_heads * self.size_per_head)
if from_seq_length == -1:
shape_return = (-1, self.num_attention_heads * self.size_per_head)
else:
shape_return = (self.batch_size, from_seq_length, self.num_attention_heads * self.size_per_head)
# reshape 2d/3d input tensors to 2d
from_tensor_2d = self.reshape(from_tensor, self.shape_from_2d)
to_tensor_2d = self.reshape(to_tensor, self.shape_to_2d)
query_out = self.query_layer(from_tensor_2d)
key_out = self.key_layer(to_tensor_2d)
value_out = self.value_layer(to_tensor_2d)
query_layer = self.reshape(query_out, shape_from)
query_layer = self.transpose(query_layer, self.trans_shape)
key_layer = self.reshape(key_out, shape_to)
key_layer = self.transpose(key_layer, self.trans_shape)
attention_scores = self.matmul_trans_b(query_layer, key_layer)
attention_scores = self.multiply(attention_scores, self.scores_mul)
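        # attention_scores is Q.K^T scaled by 1/sqrt(size_per_head). The block below
        # turns the 0/1 attention_mask into an additive bias: masked positions get
        # (1 - mask) * -10000.0 added to their scores, so they vanish after softmax.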
if self.has_attention_mask:
attention_mask = self.expand_dims(attention_mask, 1)
multiply_out = self.sub(self.cast(F.tuple_to_array((1.0,)), self.get_dtype(attention_scores)),
self.cast(attention_mask, self.get_dtype(attention_scores)))
adder = self.multiply(multiply_out, self.multiply_data)
attention_scores = self.add(adder, attention_scores)
attention_scores = self.softmax_cast(attention_scores, mstype.float32)
attention_probs = self.softmax(attention_scores)
attention_probs = self.softmax_cast(attention_probs, self.get_dtype(key_layer))
if self.use_dropout:
attention_probs = self.dropout(attention_probs)
value_layer = self.reshape(value_out, shape_to)
value_layer = self.transpose(value_layer, self.trans_shape)
context_layer = self.matmul(attention_probs, value_layer)
context_layer = self.transpose(context_layer, self.trans_shape)
context_layer = self.reshape(context_layer, shape_return)
context_layer = self.out_layer(context_layer)
return context_layer
class SelfAttention(nn.Cell):
"""
Apply self-attention.
Args:
batch_size (int): Batch size of input dataset.
hidden_size (int): Size of attention layers.
num_attention_heads (int): Number of attention heads. Default: 16.
attention_probs_dropout_prob (float): The dropout probability for
SelfAttention. Default: 0.1.
use_one_hot_embeddings (bool): Specifies whether to use one_hot encoding form. Default: False.
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.1.
has_attention_mask (bool): Specifies whether has attention mask. Default: True.
is_encdec_att (bool): Specifies whether query sequence and memory sequence are different. Default: False.
compute_type (:class:`mindspore.dtype`): Compute type in MultiheadAttention. Default: mstype.float32.
"""
def __init__(self,
batch_size,
hidden_size,
num_attention_heads=16,
attention_probs_dropout_prob=0.1,
use_one_hot_embeddings=False,
initializer_range=0.02,
hidden_dropout_prob=0.1,
has_attention_mask=True,
is_encdec_att=False,
compute_type=mstype.float32):
super(SelfAttention, self).__init__()
if hidden_size % num_attention_heads != 0:
raise ValueError("The hidden size (%d) is not a multiple of the number "
"of attention heads (%d)" % (hidden_size, num_attention_heads))
self.size_per_head = int(hidden_size / num_attention_heads)
self.is_encdec_att = is_encdec_att
self.attention = MultiheadAttention(
batch_size=batch_size,
from_tensor_width=hidden_size,
to_tensor_width=hidden_size,
out_tensor_width=hidden_size,
num_attention_heads=num_attention_heads,
size_per_head=self.size_per_head,
attention_probs_dropout_prob=attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=initializer_range,
has_attention_mask=has_attention_mask,
do_return_2d_tensor=True,
compute_type=compute_type)
self.preprocess = LayerPreprocess(in_channels=hidden_size)
self.postprocess = LayerPostprocess(dropout_prob=hidden_dropout_prob)
self.reshape = P.Reshape()
self.shape = (-1, hidden_size)
def construct(self, input_tensor, memory_tensor, attention_mask, seq_length, enc_seq_length):
"""Apply self-attention."""
input_tensor = self.reshape(input_tensor, self.shape)
memory_tensor = self.reshape(memory_tensor, self.shape)
output = self.preprocess(input_tensor)
if not self.is_encdec_att:
memory_tensor = output
attention_output = self.attention(output, memory_tensor, seq_length, enc_seq_length, attention_mask)
output = self.postprocess(attention_output, input_tensor)
return output
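# SelfAttention wraps MultiheadAttention with the pre-norm/residual pattern above.
# With is_encdec_att=False it attends over its own (normalized) input; with
# is_encdec_att=True the same cell serves as encoder-decoder cross-attention,
# taking memory_tensor (the encoder output) as keys and values.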
class FeedForward(nn.Cell):
"""
    Apply a two-layer feed-forward network.
Args:
in_channels (int): Size of the input layer.
hidden_size (int): Size of the hidden layer.
out_channels (int): Size of the output layers.
        hidden_act (str): Name of the activation function. Default: "relu".
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.1.
compute_type (:class:`mindspore.dtype`): Compute type in FeedForward. Default: mstype.float32.
"""
def __init__(self,
in_channels,
hidden_size,
out_channels,
hidden_act="relu",
initializer_range=0.02,
hidden_dropout_prob=0.1,
compute_type=mstype.float32):
super(FeedForward, self).__init__()
self.conv1 = nn.Dense(in_channels,
hidden_size,
activation=hidden_act,
weight_init=weight_variable([hidden_size, in_channels])).to_float(compute_type)
self.conv2 = nn.Dense(hidden_size,
out_channels,
weight_init=weight_variable([out_channels, hidden_size])).to_float(compute_type)
self.preprocess = LayerPreprocess(in_channels=in_channels)
self.postprocess = LayerPostprocess(dropout_prob=hidden_dropout_prob)
self.reshape = P.Reshape()
self.shape = (-1, in_channels)
self.dropout = nn.Dropout(1 - hidden_dropout_prob)
self.use_dropout = hidden_dropout_prob > 0
def construct(self, input_tensor):
input_tensor = self.reshape(input_tensor, self.shape)
output = self.preprocess(input_tensor)
output = self.conv1(output)
if self.use_dropout:
output = self.dropout(output)
output = self.conv2(output)
output = self.postprocess(output, input_tensor)
return output
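# Despite the conv1/conv2 names, the feed-forward block is two nn.Dense layers
# applied position-wise (equivalent to 1x1 convolutions over the sequence), with
# pre-norm, dropout and a residual connection around them.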
class EncoderCell(nn.Cell):
"""
Encoder cells used in Transformer.
Args:
batch_size (int): Batch size of input dataset.
hidden_size (int): Size of the encoder layers. Default: 1024.
seq_length (int): Length of input sequence. Default: 128.
num_attention_heads (int): Number of attention heads. Default: 16.
intermediate_size (int): Size of intermediate layer. Default: 4096.
        attention_probs_dropout_prob (float): The dropout probability for
                                      SelfAttention. Default: 0.1.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.1.
hidden_act (str): Activation function. Default: "relu".
compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: mstype.float32.
"""
def __init__(self,
batch_size,
hidden_size=1024,
num_attention_heads=16,
intermediate_size=4096,
attention_probs_dropout_prob=0.1,
use_one_hot_embeddings=False,
initializer_range=0.02,
hidden_dropout_prob=0.1,
hidden_act="relu",
compute_type=mstype.float32):
super(EncoderCell, self).__init__()
self.attention = SelfAttention(
batch_size=batch_size,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_probs_dropout_prob=attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=initializer_range,
hidden_dropout_prob=hidden_dropout_prob,
is_encdec_att=False,
compute_type=compute_type)
self.feedforward = FeedForward(
in_channels=hidden_size,
hidden_size=intermediate_size,
out_channels=hidden_size,
hidden_act=hidden_act,
initializer_range=initializer_range,
hidden_dropout_prob=hidden_dropout_prob,
compute_type=compute_type)
def construct(self, hidden_states, attention_mask, seq_length):
# self-attention with ln, res
attention_output = self.attention(hidden_states, hidden_states, attention_mask, seq_length, seq_length)
# feed forward with ln, res
output = self.feedforward(attention_output)
return output
class TransformerEncoder(nn.Cell):
"""
Multi-layer transformer encoder.
Args:
batch_size (int): Batch size of input dataset.
hidden_size (int): Size of the encoder layers.
seq_length (int): Length of input sequence.
num_hidden_layers (int): Number of hidden layers in encoder cells.
num_attention_heads (int): Number of attention heads in encoder cells. Default: 16.
intermediate_size (int): Size of intermediate layer in encoder cells. Default: 4096.
attention_probs_dropout_prob (float): The dropout probability for
SelfAttention. Default: 0.1.
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.1.
        hidden_act (str): Activation function used in the encoder cells. Default: "relu".
compute_type (:class:`mindspore.dtype`): Compute type. Default: mstype.float32.
"""
def __init__(self,
batch_size,
hidden_size,
num_hidden_layers,
num_attention_heads=16,
intermediate_size=4096,
attention_probs_dropout_prob=0.1,
use_one_hot_embeddings=False,
initializer_range=0.02,
hidden_dropout_prob=0.1,
hidden_act="relu",
compute_type=mstype.float32):
super(TransformerEncoder, self).__init__()
self.num_hidden_layers = num_hidden_layers
self.batch_size = batch_size
self.hidden_size = hidden_size
layers = []
for _ in range(num_hidden_layers):
layer = EncoderCell(batch_size=batch_size,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
attention_probs_dropout_prob=attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=initializer_range,
hidden_dropout_prob=hidden_dropout_prob,
hidden_act=hidden_act,
compute_type=compute_type)
layers.append(layer)
self.layers = nn.CellList(layers)
self.layer_preprocess = LayerPreprocess(in_channels=hidden_size)
self.reshape = P.Reshape()
self.shape = (-1, hidden_size)
def construct(self, input_tensor, attention_mask, seq_length):
"""Apply encoder."""
out_shape = (self.batch_size, seq_length, self.hidden_size)
prev_output = self.reshape(input_tensor, self.shape)
for layer_module in self.layers:
layer_output = layer_module(prev_output, attention_mask, seq_length)
prev_output = layer_output
prev_output = self.layer_preprocess(prev_output)
output = self.reshape(prev_output, out_shape)
return output
class DecoderCell(nn.Cell):
"""
    Decoder cells used in Transformer.
Args:
batch_size (int): Batch size of input dataset.
hidden_size (int): Size of the Transformer decoder layers. Default: 1024.
seq_length (int): Length of input sequence. Default: 128.
        enc_seq_length (int): Length of source sentences. Default: 128.
num_attention_heads (int): Number of attention heads. Default: 12.
intermediate_size (int): Size of intermediate layer. Default: 4096.
attention_probs_dropout_prob (float): The dropout probability for
SelfAttention. Default: 0.02.
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.1.
hidden_act (str): Activation function. Default: "relu".
compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: mstype.float32.
"""
def __init__(self,
batch_size,
hidden_size=1024,
num_attention_heads=12,
intermediate_size=4096,
attention_probs_dropout_prob=0.02,
use_one_hot_embeddings=False,
initializer_range=0.02,
hidden_dropout_prob=0.1,
hidden_act="relu",
compute_type=mstype.float32):
super(DecoderCell, self).__init__()
self.self_attention = SelfAttention(
batch_size=batch_size,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_probs_dropout_prob=attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=initializer_range,
is_encdec_att=False,
hidden_dropout_prob=hidden_dropout_prob,
compute_type=compute_type)
self.cross_attention = SelfAttention(
batch_size=batch_size,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_probs_dropout_prob=attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=initializer_range,
is_encdec_att=True,
hidden_dropout_prob=hidden_dropout_prob,
compute_type=compute_type)
self.feedforward = FeedForward(
in_channels=hidden_size,
hidden_size=intermediate_size,
out_channels=hidden_size,
hidden_act=hidden_act,
initializer_range=initializer_range,
hidden_dropout_prob=hidden_dropout_prob,
compute_type=compute_type)
def construct(self, hidden_states, attention_mask, enc_states, enc_attention_mask, seq_length, enc_seq_length):
# self-attention with ln, res
attention_output = self.self_attention(hidden_states, hidden_states, attention_mask, seq_length, seq_length)
# cross-attention with ln, res
attention_output = self.cross_attention(attention_output, enc_states, enc_attention_mask,
seq_length, enc_seq_length)
# feed forward with ln, res
output = self.feedforward(attention_output)
return output
class TransformerDecoder(nn.Cell):
"""
Multi-layer transformer decoder.
Args:
batch_size (int): Batch size of input dataset.
        hidden_size (int): Size of the decoder layers.
        seq_length (int): Length of input sequence.
        enc_seq_length (int): Length of source sentences.
        num_hidden_layers (int): Number of hidden layers in decoder cells.
        num_attention_heads (int): Number of attention heads in decoder cells. Default: 16.
        intermediate_size (int): Size of intermediate layer in decoder cells. Default: 4096.
attention_probs_dropout_prob (float): The dropout probability for
SelfAttention. Default: 0.1.
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.1.
        hidden_act (str): Activation function used in the decoder cells. Default: "relu".
compute_type (:class:`mindspore.dtype`): Compute type. Default: mstype.float32.
"""
def __init__(self,
batch_size,
hidden_size,
num_hidden_layers,
num_attention_heads=16,
intermediate_size=4096,
attention_probs_dropout_prob=0.1,
use_one_hot_embeddings=False,
initializer_range=0.02,
hidden_dropout_prob=0.1,
hidden_act="relu",
compute_type=mstype.float32):
super(TransformerDecoder, self).__init__()
self.num_hidden_layers = num_hidden_layers
layers = []
for _ in range(num_hidden_layers):
layer = DecoderCell(batch_size=batch_size,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
attention_probs_dropout_prob=attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=initializer_range,
hidden_dropout_prob=hidden_dropout_prob,
hidden_act=hidden_act,
compute_type=compute_type)
layers.append(layer)
self.layers = nn.CellList(layers)
self.layer_preprocess = LayerPreprocess(in_channels=hidden_size)
self.reshape = P.Reshape()
self.shape = (-1, hidden_size)
self.hidden_size = hidden_size
self.batch_size = batch_size
def construct(self, input_tensor, attention_mask, enc_states, enc_attention_mask, seq_length, enc_seq_length):
"""Apply decoder."""
out_shape = (self.batch_size, seq_length, self.hidden_size)
prev_output = self.reshape(input_tensor, self.shape)
for layer_module in self.layers:
layer_output = layer_module(prev_output, attention_mask, enc_states, enc_attention_mask,
seq_length, enc_seq_length)
prev_output = layer_output
prev_output = self.layer_preprocess(prev_output)
output = self.reshape(prev_output, out_shape)
return output
class CreateAttentionMaskFromInputMask(nn.Cell):
"""
    Create attention mask according to input mask (1 for real tokens, 0 for padding).
"""
def __init__(self):
super(CreateAttentionMaskFromInputMask, self).__init__()
self.cast = P.Cast()
self.reshape = P.Reshape()
self.shape = P.Shape()
self.batch_matmul = P.BatchMatMul()
def construct(self, input_mask):
"""Create attention mask according to input mask."""
input_shape = self.shape(input_mask)
shape_right = (input_shape[0], 1, input_shape[1])
shape_left = input_shape + (1,)
input_mask = self.cast(input_mask, mstype.float32)
mask_left = self.reshape(input_mask, shape_left)
mask_right = self.reshape(input_mask, shape_right)
attention_mask = self.batch_matmul(mask_left, mask_right)
return attention_mask
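# The attention mask is the outer product of the input mask with itself, computed
# per batch element. Illustrative example (hypothetical values): an input_mask row
# [1, 1, 0] yields the 3x3 mask
#     [[1, 1, 0],
#      [1, 1, 0],
#      [0, 0, 0]]
# so padded positions can neither attend nor be attended to.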
class PredLogProbs(nn.Cell):
"""
Get log probs.
Args:
batch_size (int): Batch size.
seq_length (int): Length of input sequence.
width (int): Hidden size.
compute_type (:class:`mindspore.dtype`): Compute type. Default: mstype.float32.
dtype (:class:`mindspore.dtype`): Compute type to compute log_softmax. Default: mstype.float32.
"""
def __init__(self,
batch_size,
width,
compute_type=mstype.float32,
dtype=mstype.float32):
super(PredLogProbs, self).__init__()
self.batch_size = batch_size
self.width = width
self.compute_type = compute_type
self.dtype = dtype
self.reshape = P.Reshape()
self.matmul = P.MatMul(transpose_b=True)
self.log_softmax = nn.LogSoftmax(axis=-1)
self.cast = P.Cast()
def construct(self,
input_tensor,
output_weights,
seq_length):
"""Get log probs."""
shape_flat_sequence_tensor = (self.batch_size * seq_length, self.width)
input_tensor = self.reshape(input_tensor, shape_flat_sequence_tensor)
input_tensor = self.cast(input_tensor, self.compute_type)
output_weights = self.cast(output_weights, self.compute_type)
logits = self.matmul(input_tensor, output_weights)
logits = self.cast(logits, self.dtype)
log_probs = self.log_softmax(logits)
return log_probs
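# PredLogProbs projects the decoder output onto the vocabulary by multiplying with
# the transposed weight matrix it is given (TransformerModel passes the embedding
# table, i.e. tied input/output embeddings) and then applies log_softmax.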
class TransformerDecoderStep(nn.Cell):
"""
Multi-layer transformer decoder step.
Args:
batch_size (int): Batch size of input dataset.
        hidden_size (int): Size of the decoder layers.
        max_decode_length (int): Max decode length.
        enc_seq_length (int): Length of source sentences.
        num_hidden_layers (int): Number of hidden layers in decoder cells.
        num_attention_heads (int): Number of attention heads in decoder cells. Default: 16.
        intermediate_size (int): Size of intermediate layer in decoder cells. Default: 4096.
        attention_probs_dropout_prob (float): The dropout probability for
                                      SelfAttention. Default: 0.3.
        use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
        hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.3.
        hidden_act (str): Activation function used in the decoder cells. Default: "relu".
        compute_type (:class:`mindspore.dtype`): Compute type. Default: mstype.float32.
        embedding_lookup (:class:`EmbeddingLookup`): Embedding lookup module.
        embedding_processor (:class:`EmbeddingPostprocessor`): Embedding postprocessor module.
        projection (:class:`PredLogProbs`): PredLogProbs module.
"""
def __init__(self,
batch_size,
hidden_size,
max_decode_length,
num_hidden_layers,
num_attention_heads=16,
intermediate_size=4096,
attention_probs_dropout_prob=0.3,
use_one_hot_embeddings=False,
initializer_range=0.02,
hidden_dropout_prob=0.3,
hidden_act="relu",
compute_type=mstype.float32,
embedding_lookup=None,
embedding_processor=None,
projection=None):
super(TransformerDecoderStep, self).__init__(auto_prefix=False)
self.num_hidden_layers = num_hidden_layers
self.tfm_embedding_lookup = embedding_lookup
self.tfm_embedding_processor = embedding_processor
self.projection = projection
self.tfm_decoder = TransformerDecoder(
batch_size=batch_size,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_hidden_layers=num_hidden_layers,
intermediate_size=intermediate_size,
attention_probs_dropout_prob=attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=initializer_range,
hidden_dropout_prob=hidden_dropout_prob,
hidden_act=hidden_act,
compute_type=compute_type)
self.ones_like = P.OnesLike()
self.shape = P.Shape()
self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask()
self.expand = P.ExpandDims()
self.multiply = P.Mul()
ones = np.ones(shape=(max_decode_length, max_decode_length))
self.future_mask = Tensor(np.tril(ones), dtype=mstype.float32)
self.cast_compute_type = CastWrapper(dst_type=compute_type)
def construct(self, input_ids, enc_states, enc_attention_mask, seq_length):
"""
Multi-layer transformer decoder step.
input_ids: [batch_size * beam_width]
"""
# process embedding
input_embedding, embedding_tables = self.tfm_embedding_lookup(input_ids)
input_embedding = self.tfm_embedding_processor(input_embedding)
input_embedding = self.cast_compute_type(input_embedding)
input_shape = self.shape(input_ids)
input_len = input_shape[1]
future_mask = self.future_mask[0:input_len:1, 0:input_len:1]
input_mask = self.ones_like(input_ids)
input_mask = self._create_attention_mask_from_input_mask(input_mask)
input_mask = self.multiply(input_mask, self.expand(future_mask, 0))
input_mask = self.cast_compute_type(input_mask)
enc_attention_mask = enc_attention_mask[::, 0:input_len:1, ::]
# call TransformerDecoder
decoder_output = self.tfm_decoder(input_embedding, input_mask, enc_states, enc_attention_mask, -1, seq_length)
# take the last step
decoder_output = decoder_output[::, input_len-1:input_len:1, ::]
# projection and log_prob
log_probs = self.projection(decoder_output, embedding_tables, 1)
return log_probs
@constexpr
def convert_np_to_tensor_encoder(seq_length):
ones = np.ones(shape=(seq_length, seq_length))
return Tensor(np.tril(ones), dtype=mstype.float32)
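# convert_np_to_tensor_encoder builds the lower-triangular "future" mask with
# np.tril; being a @constexpr, it is evaluated at graph-compilation time for the
# given seq_length rather than at every forward step.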
class TransformerModel(nn.Cell):
"""
Transformer with encoder and decoder.
Args:
        config (:class:`TransformerConfig`): Configuration for Transformer.
is_training (bool): True for training mode. False for eval mode.
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
"""
def __init__(self,
config,
is_training,
use_one_hot_embeddings=False):
super(TransformerModel, self).__init__()
config = copy.deepcopy(config)
self.is_training = is_training
if not is_training:
config.hidden_dropout_prob = 0.0
config.attention_probs_dropout_prob = 0.0
self.batch_size = config.batch_size
self.hidden_size = config.hidden_size
self.num_hidden_layers = config.num_hidden_layers
self.embedding_size = config.hidden_size
self.last_idx = self.num_hidden_layers - 1
self.beam_width = config.beam_width
self.max_decode_length = config.max_decode_length
self.tfm_embedding_lookup = EmbeddingLookup(
vocab_size=config.vocab_size,
embedding_size=self.embedding_size,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=config.initializer_range)
self.tfm_embedding_postprocessor_for_encoder = EmbeddingPostprocessor(
embedding_size=self.embedding_size,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=0.02,
max_position_embeddings=config.max_position_embeddings,
dropout_prob=config.hidden_dropout_prob)
self.tfm_embedding_postprocessor_for_decoder = EmbeddingPostprocessor(
embedding_size=self.embedding_size,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=0.02,
max_position_embeddings=config.max_position_embeddings,
dropout_prob=config.hidden_dropout_prob)
self.tfm_encoder = TransformerEncoder(
batch_size=self.batch_size,
hidden_size=self.hidden_size,
num_attention_heads=config.num_attention_heads,
num_hidden_layers=self.num_hidden_layers,
intermediate_size=config.intermediate_size,
attention_probs_dropout_prob=config.attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=config.initializer_range,
hidden_dropout_prob=config.hidden_dropout_prob,
hidden_act=config.hidden_act,
compute_type=config.compute_type)
if is_training:
self.projection = PredLogProbs(
batch_size=self.batch_size,
width=self.hidden_size,
compute_type=config.compute_type,
dtype=config.dtype)
self.tfm_decoder = TransformerDecoder(
batch_size=self.batch_size,
hidden_size=self.hidden_size,
num_attention_heads=config.num_attention_heads,
num_hidden_layers=self.num_hidden_layers,
intermediate_size=config.intermediate_size,
attention_probs_dropout_prob=config.attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=config.initializer_range,
hidden_dropout_prob=config.hidden_dropout_prob,
hidden_act=config.hidden_act,
compute_type=config.compute_type)
else:
self.projection = PredLogProbs(
batch_size=self.batch_size * config.beam_width,
width=self.hidden_size,
compute_type=config.compute_type,
dtype=config.dtype)
self.tfm_decoder = TransformerDecoderStep(
batch_size=self.batch_size * config.beam_width,
hidden_size=self.hidden_size,
max_decode_length=config.max_decode_length,
num_hidden_layers=config.num_hidden_layers,
num_attention_heads=config.num_attention_heads,
intermediate_size=config.intermediate_size,
attention_probs_dropout_prob=config.attention_probs_dropout_prob,
use_one_hot_embeddings=False,
initializer_range=config.initializer_range,
hidden_dropout_prob=config.hidden_dropout_prob,
hidden_act=config.hidden_act,
compute_type=config.compute_type,
embedding_lookup=self.tfm_embedding_lookup,
embedding_processor=self.tfm_embedding_postprocessor_for_decoder,
projection=self.projection)
self.tfm_decoder = BeamSearchDecoder(
batch_size=config.batch_size,
seq_length=config.seq_length,
vocab_size=config.vocab_size,
decoder=self.tfm_decoder,
beam_width=config.beam_width,
length_penalty_weight=config.length_penalty_weight,
max_decode_length=config.max_decode_length)
self.tfm_decoder.add_flags(loop_can_unroll=True)
self.tile_beam = TileBeam(beam_width=self.beam_width)
ones = np.ones(shape=(self.batch_size, self.max_decode_length))
self.encdec_mask = Tensor(ones, mstype.float32)
self.cast = P.Cast()
self.dtype = config.dtype
self.cast_compute_type = CastWrapper(dst_type=config.compute_type)
self.expand = P.ExpandDims()
self.multiply = P.Mul()
self.shape = P.Shape()
self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask()
def construct(self, source_ids, source_mask, target_ids=None, target_mask=None):
"""Transformer with encoder and decoder."""
seq_length = self.shape(source_ids)[1]
# process source sentence
src_word_embeddings, embedding_tables = self.tfm_embedding_lookup(source_ids)
src_embedding_output = self.tfm_embedding_postprocessor_for_encoder(src_word_embeddings)
# attention mask [batch_size, seq_length, seq_length]
enc_attention_mask = self._create_attention_mask_from_input_mask(source_mask)
# transformer encoder
encoder_output = self.tfm_encoder(self.cast_compute_type(src_embedding_output),
self.cast_compute_type(enc_attention_mask),
seq_length)
if self.is_training:
future_mask = convert_np_to_tensor_encoder(seq_length)
# process target sentence
tgt_word_embeddings, _ = self.tfm_embedding_lookup(target_ids)
tgt_embedding_output = self.tfm_embedding_postprocessor_for_decoder(tgt_word_embeddings)
# attention mask [batch_size, seq_length, seq_length]
tgt_attention_mask = self._create_attention_mask_from_input_mask(target_mask)
tgt_attention_mask = self.multiply(tgt_attention_mask, self.expand(future_mask, 0))
# transformer decoder
decoder_output = self.tfm_decoder(self.cast_compute_type(tgt_embedding_output),
self.cast_compute_type(tgt_attention_mask),
encoder_output, enc_attention_mask,
seq_length, seq_length)
# calculate logits and log_probs
log_probs = self.projection(decoder_output, embedding_tables, seq_length)
ret = log_probs
else:
beam_encoder_output = self.tile_beam(encoder_output)
enc_attention_mask = self.multiply(enc_attention_mask[::, 0:1:1, ::], self.expand(self.encdec_mask, -1))
beam_enc_attention_mask = self.tile_beam(enc_attention_mask)
beam_enc_attention_mask = self.cast_compute_type(beam_enc_attention_mask)
predicted_ids = self.tfm_decoder(beam_encoder_output, beam_enc_attention_mask)
ret = predicted_ids
return ret
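# Putting the pieces together (illustrative sketch; shapes and values below are
# placeholders):
#
#     import numpy as np
#     from mindspore import Tensor
#
#     cfg = TransformerConfig(batch_size=2, seq_length=16, vocab_size=36560)
#     net = TransformerModel(cfg, is_training=True)
#     src_ids = Tensor(np.ones((2, 16), np.int32))
#     src_mask = Tensor(np.ones((2, 16), np.int32))
#     tgt_ids = Tensor(np.ones((2, 16), np.int32))
#     tgt_mask = Tensor(np.ones((2, 16), np.int32))
#     log_probs = net(src_ids, src_mask, tgt_ids, tgt_mask)  # (batch*seq, vocab)
#
# In eval mode (is_training=False) only source_ids and source_mask are needed and
# the network returns beam-searched predicted ids instead of log probabilities.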