发布于2020-04-03 10:46 阅读(1566) 评论(0) 点赞(28) 收藏(5)
MultiHeadAttention.py
-
- import tensorflow as tf
- from tensorflow import keras
- from tensorflow.keras import layers
-
- """
- class MultiHeadAttention(keras.Model):
- # https://machinetalk.org/2019/04/29/create-the-transformer-with-tensorflow-2-0/
- def __init__(self, model_size, h, dropout):
- super(MultiHeadAttention, self).__init__()
- self.query_size = model_size // h
- self.key_size = model_size // h
- self.value_size = model_size // h
- self.h = h
- self.wq = [layers.Dense(self.query_size) for _ in range(h)]
- self.wk = [layers.Dense(self.key_size) for _ in range(h)]
- self.wv = [layers.Dense(self.value_size) for _ in range(h)]
- self.wo = layers.Dense(model_size)
- self.dropout = layers.Dropout(dropout)
- def call(self, query, value):
- # query has shape (batch, query_len, model_size)
- # value has shape (batch, value_len, model_size)
- heads = []
- for i in range(self.h):
- score = self.dropout(tf.matmul(self.wq[i](query), self.wk[i](value), transpose_b=True))
- # Here we scale the score as described in the paper
- score /= tf.math.sqrt(tf.dtypes.cast(self.key_size, tf.float32))
- # score has shape (batch, query_len, value_len)
- alignment = tf.nn.softmax(score, axis=2)
- # alignment has shape (batch, query_len, value_len)
- head = tf.matmul(alignment, self.wv[i](value))
- # head has shape (batch, decoder_len, value_size)
- heads.append(head)
- # Concatenate all the attention heads
- # so that the last dimension summed up to model_size
- heads = tf.concat(heads, axis=2)
- heads = self.wo(heads)
- # heads has shape (batch, query_len, model_size)
- return heads
- """
-
class MultiHeadAttention(keras.Model):
    """Multi-head scaled dot-product attention.

    Adapted from https://www.tensorflow.org/tutorials/text/transformer.

    ``split_heads`` treats axis 0 of each input as the batch axis, so q, k and
    v are batch-first: ``(batch_size, seq_len, features)``.

    Args:
        d_model: Output width of the q/k/v and output projections; split
            evenly across ``num_heads`` heads of size ``d_model // num_heads``.
        num_heads: Number of attention heads; must divide ``d_model``.
        dropout: Dropout rate applied to the softmax attention weights.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads, dropout):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model

        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would let a bad configuration through and fail
        # later with an opaque reshape error in ``split_heads``.
        if d_model % self.num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )

        self.depth = d_model // self.num_heads

        # Creation order of the sub-layers is kept stable so that order-based
        # weight saving/loading keeps matching existing files.
        self.wq = layers.Dense(d_model)
        self.wk = layers.Dense(d_model)
        self.wv = layers.Dense(d_model)
        self.dropout = layers.Dropout(dropout)

        self.dense = layers.Dense(d_model)

    def scaled_dot_product_attention(self, q, k, v, mask):
        """Compute attention outputs and attention weights.

        q, k and v must share their leading dimensions, and k and v must have
        the same second-to-last dimension (seq_len_k == seq_len_v). The mask's
        shape depends on its kind (padding or look-ahead), but it must
        broadcast against the attention logits so it can be added to them.

        Args:
            q: Queries, shape ``(..., seq_len_q, depth)``.
            k: Keys, shape ``(..., seq_len_k, depth)``.
            v: Values, shape ``(..., seq_len_v, depth_v)``.
            mask: Float tensor broadcastable to ``(..., seq_len_q, seq_len_k)``
                in which 1 marks a position to hide, or ``None``.

        Returns:
            ``(output, attention_weights)`` with shapes
            ``(..., seq_len_q, depth_v)`` and ``(..., seq_len_q, seq_len_k)``.
            The returned weights have already had dropout applied.
        """
        matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)

        # Scale by sqrt(depth) to keep the logits' magnitude independent of the
        # key width.
        dk = tf.cast(tf.shape(k)[-1], tf.float32)
        scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

        # Masked positions receive a very large negative logit, so their
        # softmax weight is effectively zero.
        if mask is not None:
            scaled_attention_logits += (mask * -1e9)

        # Softmax over the last axis (seq_len_k) so each query's weights sum
        # to 1, then dropout on the weights.
        attention_weights = self.dropout(tf.nn.softmax(scaled_attention_logits, axis=-1))  # (..., seq_len_q, seq_len_k)
        output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)

        return output, attention_weights

    def split_heads(self, x, batch_size):
        """Split the last axis into ``(num_heads, depth)`` and move heads forward.

        Reshapes ``x`` from ``(batch_size, seq_len, d_model)`` to
        ``(batch_size, num_heads, seq_len, depth)``.
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, q, k, v, mask):
        """Project q/k/v, attend independently per head, then merge the heads.

        Args:
            q: Queries, ``(batch_size, seq_len_q, features)``.
            k: Keys, ``(batch_size, seq_len_k, features)``.
            v: Values, ``(batch_size, seq_len_v, features)``; seq_len_v must
                equal seq_len_k.
            mask: Mask forwarded to ``scaled_dot_product_attention`` (1 hides
                a position), or ``None``.

        Returns:
            ``(output, attention_weights)``: output has shape
            ``(batch_size, seq_len_q, d_model)`` and attention_weights has
            shape ``(batch_size, num_heads, seq_len_q, seq_len_k)``.
        """
        batch_size = tf.shape(q)[0]

        q = self.wq(q)  # (batch_size, seq_len_q, d_model)
        k = self.wk(k)  # (batch_size, seq_len_k, d_model)
        v = self.wv(v)  # (batch_size, seq_len_v, d_model)

        q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
        k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
        v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)

        # scaled_attention: (batch_size, num_heads, seq_len_q, depth)
        # attention_weights: (batch_size, num_heads, seq_len_q, seq_len_k)
        scaled_attention, attention_weights = self.scaled_dot_product_attention(
            q, k, v, mask)

        # Undo the head split: (batch_size, seq_len_q, num_heads, depth) ...
        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])
        # ... then merge heads back into (batch_size, seq_len_q, d_model).
        concat_attention = tf.reshape(scaled_attention,
                                      (batch_size, -1, self.d_model))

        output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)

        return output, attention_weights
rztx.py
-
- import tensorflow as tf
- from tensorflow import keras
- from tensorflow.keras import layers
- from tensorflow.keras import activations
- import tensorflow_addons as tfa
-
- from MultiHeadAttention import MultiHeadAttention
-
class RZTXEncoderLayer(keras.Model):
    """ReZero-style Transformer encoder layer.

    Each sub-layer (self-attention, then a position-wise feed-forward network)
    is added to the residual stream after being scaled by a single learnable
    scalar ``resweight`` initialised to 0. No layer normalisation is used.

    Inputs are batch-first, ``(batch, seq_len, d_model)``, because the
    attention module treats axis 0 as the batch axis.

    Args:
        d_model: Embedding size of the input and output.
        nhead: Number of attention heads.
        dim_feedforward: Hidden width of the feed-forward network.
        dropout: Dropout rate used throughout the layer.
        activation: ``"relu"`` or ``"gelu"``.

    Raises:
        ValueError: If ``activation`` is not ``"relu"`` or ``"gelu"``.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation='relu'):
        super(RZTXEncoderLayer, self).__init__()
        # Self-attention over the input sequence.
        self.self_attn = MultiHeadAttention(d_model, nhead, dropout=dropout)
        # Position-wise feed-forward network: d_model -> dim_feedforward -> d_model.
        self.linear1 = layers.Dense(dim_feedforward)
        self.dropout = layers.Dropout(dropout)
        self.linear2 = layers.Dense(d_model)
        self.dropout1 = layers.Dropout(dropout)
        self.dropout2 = layers.Dropout(dropout)
        # Learnable residual weight (alpha). Starting at 0 makes the whole layer
        # the identity map at initialisation.
        self.resweight = tf.Variable(0.0, trainable=True)

        if activation == "relu":
            self.activation = activations.relu
        elif activation == "gelu":
            self.activation = tfa.activations.gelu
        else:
            # Fail fast: previously an unknown name left ``self.activation``
            # unset, and the layer only crashed later inside ``call``.
            raise ValueError(f"activation should be relu/gelu, not {activation!r}")

    def __setstate__(self, state):
        # Fallback for states saved without an 'activation' entry: use ReLU.
        # NOTE(review): this relies on a parent class providing __setstate__;
        # confirm keras.Model (or one of its bases) does before relying on it.
        if 'activation' not in state:
            state['activation'] = activations.relu
        super().__setstate__(state)

    def call(self, src, mask=None):
        """Apply the self-attention and feed-forward residual branches.

        Args:
            src: Batch-first input, assumed ``(batch, seq_len, d_model)``.
            mask: Optional attention mask forwarded to ``self_attn``.

        Returns:
            Tensor with the same shape as ``src``.
        """
        # Self-attention sub-layer: src + dropout(alpha * attn(src)).
        attn_out, _ = self.self_attn(src, src, src, mask)  # (batch, seq_len, d_model)
        src = src + self.dropout1(attn_out * self.resweight)

        # Position-wise feed-forward sub-layer: src + dropout(alpha * ffn(src)).
        ff_out = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(ff_out * self.resweight)
        return src
-
class RZTXDecoderLayer(keras.Model):
    """ReZero-style Transformer decoder layer.

    The layer has three residual branches: self-attention over the target,
    cross-attention from the target to the encoder memory, and a position-wise
    feed-forward network. Each branch is scaled by one shared learnable scalar
    ``resweight`` (initialised to 0) before being added back. No layer
    normalisation is used.

    Inputs are batch-first, ``(batch, seq_len, d_model)``, because the
    attention modules treat axis 0 as the batch axis.

    Args:
        d_model: Embedding size of target, memory and output.
        nhead: Number of attention heads.
        dim_feedforward: Hidden width of the feed-forward network.
        dropout: Dropout rate used throughout the layer.
        activation: ``"relu"`` or ``"gelu"``.

    Raises:
        ValueError: If ``activation`` is not ``"relu"`` or ``"gelu"``.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super(RZTXDecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = MultiHeadAttention(d_model, nhead, dropout=dropout)
        # Position-wise feed-forward network: d_model -> dim_feedforward -> d_model.
        self.linear1 = layers.Dense(dim_feedforward)
        self.dropout = layers.Dropout(dropout)
        self.linear2 = layers.Dense(d_model)

        self.dropout1 = layers.Dropout(dropout)
        self.dropout2 = layers.Dropout(dropout)
        self.dropout3 = layers.Dropout(dropout)
        # Learnable residual weight (alpha), starting at 0 so the layer is the
        # identity on ``tgt`` at initialisation.
        self.resweight = tf.Variable(0.0, trainable=True)

        if activation == "relu":
            self.activation = activations.relu
        elif activation == "gelu":
            self.activation = tfa.activations.gelu
        else:
            raise ValueError(f"activation should be relu/gelu, not {activation!r}")

    def call(self, tgt, memory, tgt_mask=None, memory_mask=None):
        """Run self-attention, cross-attention and feed-forward branches.

        Args:
            tgt: Batch-first target sequence, assumed ``(batch, tgt_len, d_model)``.
            memory: Batch-first encoder output, assumed ``(batch, src_len, d_model)``.
            tgt_mask: Optional mask for target self-attention.
            memory_mask: Optional mask for target-to-memory attention.

        Returns:
            Tensor with the same shape as ``tgt``.
        """
        # Self-attention over the target.
        tgt2, _ = self.self_attn(tgt, tgt, tgt, tgt_mask)
        tgt = tgt + self.dropout1(tgt2) * self.resweight

        # Cross-attention: Q = tgt; K = V = memory.
        tgt2, _ = self.multihead_attn(tgt, memory, memory, memory_mask)
        tgt = tgt + self.dropout2(tgt2) * self.resweight

        # Objects lacking an ``activation`` attribute (backward compatibility)
        # fall back to ReLU. The old fallback called ``F.relu``, but ``F`` is
        # not defined in this module, so that path raised NameError.
        activation = getattr(self, "activation", activations.relu)
        tgt2 = self.linear2(self.dropout(activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2) * self.resweight
        return tgt
-
- """
- encoder_layer = RZTXEncoderLayer(d_model=512, nhead=8)
- src = tf.random.normal([32, 10, 512]) # [bs,q,emb]
- out = encoder_layer(src)
- print(out.shape)
- decoder_layer = RZTXDecoderLayer(d_model=512, nhead=8)
- memory = tf.random.normal([32, 10, 512])
- tgt = tf.random.normal([32, 20, 512])
- out = decoder_layer(tgt, memory)
- print(out.shape)
- """
原文链接:https://blog.csdn.net/coolsunxu/article/details/105266057
作者:以拯救苍生己任
链接:https://www.pythonheidong.com/blog/article/301755/79a6d8b77ba224d8b524/
来源:python黑洞网
任何形式的转载都请注明出处,如有侵权,一经发现,必将追究其法律责任。
昵称:
评论内容:(最多支持255个字符)
---无人问津也好,技不如人也罢,你都要试着安静下来,去做自己该做的事,而不是让内心的烦躁、焦虑,坏掉你本来就不多的热情和定力
Copyright © 2018-2021 python黑洞网 All Rights Reserved 版权所有,并保留所有权利。 京ICP备18063182号-1
投诉与举报,广告合作请联系vgs_info@163.com或QQ3083709327
免责声明:网站文章均由用户上传,仅供读者学习交流使用,禁止用做商业用途。若文章涉及色情,反动,侵权等违法信息,请向我们举报,一经核实我们会立即删除!