代码拉取完成,页面将自动刷新
import tensorflow as tf
def encoder(inputs, units=64, output_units=128):
    """Run one single-head scaled dot-product self-attention step, then project.

    Keras layers must be instantiated first and then called on a tensor.
    Passing a tensor to a layer constructor binds it to a config argument
    (e.g. ``LayerNormalization(axis=...)``) and fails.

    Args:
        inputs: Input tensor. Presumably shaped (batch, seq_len, features);
            TODO confirm against callers. Attention mixes along the
            second-to-last axis.
        units: Width of the query/key/value projections (default 64).
        output_units: Width of the final Dense projection (default 128).

    Returns:
        Tensor with the same leading axes as ``inputs`` and a last axis of
        size ``output_units``.
    """
    queries = tf.keras.layers.Dense(units, name='Queries')(inputs)
    keys = tf.keras.layers.Dense(units, name='Key')(inputs)
    values = tf.keras.layers.Dense(units, name='Values')(inputs)

    # Pairwise attention logits between positions: (..., seq_len, seq_len).
    # Scale by sqrt(units) so the softmax does not saturate as width grows.
    scores = tf.matmul(queries, keys, transpose_b=True)
    scores = scores / tf.math.sqrt(tf.cast(units, scores.dtype))
    weights = tf.keras.layers.Softmax(axis=-1)(scores)

    # Weighted sum of value vectors, then normalize and project.
    context = tf.matmul(weights, values)
    context = tf.keras.layers.LayerNormalization()(context)
    return tf.keras.layers.Dense(output_units)(context)
class PositionEmbedding(tf.keras.layers.Layer):
    """Trainable position embedding merged into (or returned alongside) the inputs.

    Args:
        input_dim: Number of rows in the embedding table. With default
            position ids and no hierarchy, this bounds the usable sequence
            length, because the table is sliced to ``[:seq_len]``.
        output_dim: Size of each embedding vector.
        merge_mode: How embeddings are combined with the inputs:
            'add' returns ``inputs + emb``;
            'mul' returns ``inputs * (emb + 1)``;
            'zero' returns ``emb`` only;
            any other value concatenates ``[inputs, emb]`` on the last axis.
        hierarchical: Falsy disables the hierarchy. ``True`` uses alpha=0.4;
            any other truthy value is used directly as alpha. Position id
            ``p`` then maps to ``alpha * E'[p // input_dim]
            + (1 - alpha) * E'[p % input_dim]``, where
            ``E' = (E - alpha * E[0]) / (1 - alpha)``. This lets ids reach
            ``input_dim ** 2 - 1``.
        embeddings_initializer: Initializer name or object for the table.
        custom_position_ids: If True, ``call`` expects
            ``[inputs, position_ids]`` instead of deriving ids 0..seq_len-1.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        merge_mode='add',
        hierarchical=None,
        embeddings_initializer='zeros',
        custom_position_ids=False,
        **kwargs
    ):
        super(PositionEmbedding, self).__init__(**kwargs)
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.merge_mode = merge_mode
        self.hierarchical = hierarchical
        # `initializers` was never imported in this file; resolve it through tf.keras.
        self.embeddings_initializer = tf.keras.initializers.get(embeddings_initializer)
        self.custom_position_ids = custom_position_ids

    def build(self, input_shape):
        """Create the (input_dim, output_dim) embedding table."""
        super(PositionEmbedding, self).build(input_shape)
        self.embeddings = self.add_weight(
            name='embeddings',
            shape=(self.input_dim, self.output_dim),
            initializer=self.embeddings_initializer
        )

    def call(self, inputs):
        """Merge position embeddings into ``inputs``.

        If ``custom_position_ids`` is set, ``inputs`` is a pair
        ``(inputs, position_ids)`` and the second element supplies the
        position ids. Otherwise ids 0..seq_len-1 are used, with seq_len
        taken from axis 1 of ``inputs``.
        """
        if self.custom_position_ids:
            inputs, position_ids = inputs
            position_ids = tf.convert_to_tensor(position_ids)
            # Gather indices must be integers.
            if not position_ids.dtype.is_integer:
                position_ids = tf.cast(position_ids, 'int32')
        else:
            input_shape = tf.shape(inputs)
            batch_size, seq_len = input_shape[0], input_shape[1]
            # Leading axis of size 1 so the result broadcasts over the batch.
            position_ids = tf.range(0, seq_len, dtype='int32')[None]

        if self.hierarchical:
            alpha = 0.4 if self.hierarchical is True else self.hierarchical
            embeddings = self.embeddings - alpha * self.embeddings[:1]
            embeddings = embeddings / (1 - alpha)
            embeddings_x = tf.gather(embeddings, position_ids // self.input_dim)
            embeddings_y = tf.gather(embeddings, position_ids % self.input_dim)
            embeddings = alpha * embeddings_x + (1 - alpha) * embeddings_y
        elif self.custom_position_ids:
            embeddings = tf.gather(self.embeddings, position_ids)
        else:
            embeddings = self.embeddings[None, :seq_len]

        if self.merge_mode == 'add':
            return inputs + embeddings
        if self.merge_mode == 'mul':
            return inputs * (embeddings + 1.0)
        if self.merge_mode == 'zero':
            return embeddings
        # Concatenation cannot broadcast, so repeat the shared (1, seq, dim)
        # embeddings across the batch when the ids were derived here.
        if not self.custom_position_ids:
            embeddings = tf.tile(embeddings, [batch_size, 1, 1])
        return tf.concat([inputs, embeddings], axis=-1)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = {
            'input_dim': self.input_dim,
            'output_dim': self.output_dim,
            'merge_mode': self.merge_mode,
            'hierarchical': self.hierarchical,
            'embeddings_initializer':
                tf.keras.initializers.serialize(self.embeddings_initializer),
            'custom_position_ids': self.custom_position_ids,
        }
        base_config = super(PositionEmbedding, self).get_config()
        return {**base_config, **config}
tf.keras.
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。