1 Star 5 Fork 3

多读书多看报少玩游戏少睡觉/Easy SmartHome

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
transformer-checkpoint.ipynb 4.60 KB
一键复制 编辑 原始数据 按行查看 历史
import tensorflow as tf
  • BERT encoder
def encoder(inputs):
    """Build a simplified, single-head attention-style encoder block.

    The input is projected into 64-unit query, key and value tensors. The
    query and key projections are combined element-wise (not with a matrix
    product), layer-normalized and passed through a softmax to form the
    scores. The scores then gate the value projection element-wise, and the
    normalized result is projected to 128 units.

    Args:
        inputs: Tensor (or Keras symbolic tensor) fed to the three Dense
            projections.

    Returns:
        The output tensor of the final 128-unit Dense layer.
    """
    Q = tf.keras.layers.Dense(64, name='Queries')(inputs)
    K = tf.keras.layers.Dense(64, name='Key')(inputs)
    V = tf.keras.layers.Dense(64, name='Values')(inputs)

    # Keras layers have to be instantiated first and then called on a tensor.
    # Passing the tensor to the constructor (as the code did before) yields a
    # Layer object rather than a tensor, so the following multiply failed.
    # The Multiply layer is used for the element-wise products so that the
    # function also works on symbolic Keras inputs.
    qk = tf.keras.layers.Multiply()([Q, K])
    normalized_qk = tf.keras.layers.LayerNormalization()(qk)
    score = tf.keras.layers.Softmax()(normalized_qk)

    Z = tf.keras.layers.Multiply()([score, V])
    normalized_z = tf.keras.layers.LayerNormalization()(Z)
    encoder_output = tf.keras.layers.Dense(128)(normalized_z)
    return encoder_output
  • BERT position information (position embedding layer)
class PositionEmbedding(tf.keras.layers.Layer):
    """Trainable position embedding that is merged into the layer input.

    A weight table of shape ``(input_dim, output_dim)`` is created in
    ``build``. In ``call`` the rows selected by the position ids are combined
    with the input according to ``merge_mode``.

    Args:
        input_dim: Number of rows (positions) in the embedding table.
        output_dim: Size of each position embedding vector.
        merge_mode: ``'add'`` returns ``inputs + embeddings``; ``'mul'``
            returns ``inputs * (embeddings + 1.0)``; ``'zero'`` returns the
            embeddings alone; any other value concatenates the embeddings
            to the inputs along the last axis.
        hierarchical: Falsy to disable. ``True`` uses ``alpha = 0.4``; any
            other truthy value is used as ``alpha`` itself. When enabled,
            each position id is split into ``id // input_dim`` and
            ``id % input_dim`` and the two looked-up rows are blended.
        embeddings_initializer: Initializer (name or object) for the table.
        custom_position_ids: If true, ``call`` expects ``[inputs,
            position_ids]`` instead of a single tensor.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        merge_mode='add',
        hierarchical=None,
        embeddings_initializer='zeros',
        custom_position_ids=False,
        **kwargs
    ):
        super(PositionEmbedding, self).__init__(**kwargs)
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.merge_mode = merge_mode
        self.hierarchical = hierarchical
        # `initializers` was never imported in this file; go through the
        # already-imported `tf` namespace instead.
        self.embeddings_initializer = tf.keras.initializers.get(
            embeddings_initializer
        )
        self.custom_position_ids = custom_position_ids

    def build(self, input_shape):
        """Create the ``(input_dim, output_dim)`` embedding table."""
        super(PositionEmbedding, self).build(input_shape)
        self.embeddings = self.add_weight(
            name='embeddings',
            shape=(self.input_dim, self.output_dim),
            initializer=self.embeddings_initializer
        )

    def call(self, inputs):
        """Merge position embeddings into ``inputs``.

        If ``custom_position_ids`` is set, ``inputs`` must be a pair whose
        second element holds the user-supplied position ids. Otherwise the
        ids ``0 .. seq_len - 1`` are generated from the second axis of
        ``inputs``.
        """
        # Native tf ops are used throughout because the backend alias `K`
        # was never imported in this file.
        if self.custom_position_ids:
            inputs, position_ids = inputs
            if not tf.as_dtype(position_ids.dtype).is_integer:
                position_ids = tf.cast(position_ids, tf.int32)
        else:
            input_shape = tf.shape(inputs)
            batch_size, seq_len = input_shape[0], input_shape[1]
            # Leading axis of size 1 so the ids broadcast over the batch.
            position_ids = tf.range(0, seq_len, dtype=tf.int32)[None]

        if self.hierarchical:
            alpha = 0.4 if self.hierarchical is True else self.hierarchical
            # Re-base the table on its first row, then blend the rows
            # selected by the quotient and remainder of each position id.
            embeddings = self.embeddings - alpha * self.embeddings[:1]
            embeddings = embeddings / (1 - alpha)
            embeddings_x = tf.gather(embeddings, position_ids // self.input_dim)
            embeddings_y = tf.gather(embeddings, position_ids % self.input_dim)
            embeddings = alpha * embeddings_x + (1 - alpha) * embeddings_y
        else:
            if self.custom_position_ids:
                embeddings = tf.gather(self.embeddings, position_ids)
            else:
                embeddings = self.embeddings[None, :seq_len]

        if self.merge_mode == 'add':
            return inputs + embeddings
        elif self.merge_mode == 'mul':
            return inputs * (embeddings + 1.0)
        elif self.merge_mode == 'zero':
            return embeddings
        else:
            if not self.custom_position_ids:
                # Generated embeddings have batch size 1; concatenation needs
                # them repeated to the real batch size.
                embeddings = tf.tile(embeddings, [batch_size, 1, 1])
            return tf.concat([inputs, embeddings], axis=-1)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = {
            'input_dim': self.input_dim,
            'output_dim': self.output_dim,
            'merge_mode': self.merge_mode,
            'hierarchical': self.hierarchical,
            'embeddings_initializer': tf.keras.initializers.serialize(
                self.embeddings_initializer
            ),
            'custom_position_ids': self.custom_position_ids,
        }
        base_config = super(PositionEmbedding, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
tf.keras.
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/xulucky/easy-smarthome.git
git@gitee.com:xulucky/easy-smarthome.git
xulucky
easy-smarthome
Easy SmartHome
master

搜索帮助