
Tensorflow2 Repo for "ReZero is All You Need: Fast Convergence at Large Depth"

Published 2020-04-03 10:46
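For context (this note is not in the original post): the layers below implement the ReZero residual connection from the paper, which drops LayerNorm and instead rescales each residual branch by a single trainable scalar initialized to zero,

    x_{i+1} = x_i + α_i · F(x_i),   with α_i = 0 at initialization,

so every layer starts out as the identity map and gradually learns how much of its sub-layer output to mix in. In the code this scalar is the `resweight` variable in rztx.py.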


MultiHeadAttention.py

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

"""
class MultiHeadAttention(keras.Model):
    # https://machinetalk.org/2019/04/29/create-the-transformer-with-tensorflow-2-0/
    def __init__(self, model_size, h, dropout):
        super(MultiHeadAttention, self).__init__()
        self.query_size = model_size // h
        self.key_size = model_size // h
        self.value_size = model_size // h
        self.h = h
        self.wq = [layers.Dense(self.query_size) for _ in range(h)]
        self.wk = [layers.Dense(self.key_size) for _ in range(h)]
        self.wv = [layers.Dense(self.value_size) for _ in range(h)]
        self.wo = layers.Dense(model_size)
        self.dropout = layers.Dropout(dropout)

    def call(self, query, value):
        # query has shape (batch, query_len, model_size)
        # value has shape (batch, value_len, model_size)
        heads = []
        for i in range(self.h):
            score = self.dropout(tf.matmul(self.wq[i](query), self.wk[i](value), transpose_b=True))
            # Scale the score as described in the paper
            score /= tf.math.sqrt(tf.dtypes.cast(self.key_size, tf.float32))
            # score has shape (batch, query_len, value_len)
            alignment = tf.nn.softmax(score, axis=2)
            # alignment has shape (batch, query_len, value_len)
            head = tf.matmul(alignment, self.wv[i](value))
            # head has shape (batch, decoder_len, value_size)
            heads.append(head)
        # Concatenate all the attention heads
        # so that the last dimension sums to model_size
        heads = tf.concat(heads, axis=2)
        heads = self.wo(heads)
        # heads has shape (batch, query_len, model_size)
        return heads
"""


class MultiHeadAttention(keras.Model):
    # https://www.tensorflow.org/tutorials/text/transformer
    def __init__(self, d_model, num_heads, dropout):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        assert d_model % self.num_heads == 0
        self.depth = d_model // self.num_heads
        self.wq = layers.Dense(d_model)
        self.wk = layers.Dense(d_model)
        self.wv = layers.Dense(d_model)
        self.dropout = layers.Dropout(dropout)
        self.dense = layers.Dense(d_model)

    def scaled_dot_product_attention(self, q, k, v, mask):
        """Compute the attention weights.
        q, k, v must have matching leading dimensions.
        k, v must have a matching second-to-last dimension, i.e. seq_len_k = seq_len_v.
        The mask has different shapes depending on its type (padding or look-ahead),
        but it must be broadcastable for the addition.
        Args:
            q: query, shape == (..., seq_len_q, depth)
            k: key, shape == (..., seq_len_k, depth)
            v: value, shape == (..., seq_len_v, depth_v)
            mask: float tensor with a shape broadcastable to
                  (..., seq_len_q, seq_len_k). Defaults to None.
        Returns:
            output, attention_weights
        """
        matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)
        # Scale matmul_qk
        dk = tf.cast(tf.shape(k)[-1], tf.float32)
        scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
        # Add the mask to the scaled tensor.
        if mask is not None:
            scaled_attention_logits += (mask * -1e9)
        # Softmax is normalized on the last axis (seq_len_k) so that the scores sum to 1.
        attention_weights = self.dropout(tf.nn.softmax(scaled_attention_logits, axis=-1))  # (..., seq_len_q, seq_len_k)
        output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)
        return output, attention_weights

    def split_heads(self, x, batch_size):
        """Split the last dimension into (num_heads, depth).
        Transpose the result so that the shape is (batch_size, num_heads, seq_len, depth).
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, q, k, v, mask):
        batch_size = tf.shape(q)[0]
        q = self.wq(q)  # (batch_size, seq_len, d_model)
        k = self.wk(k)  # (batch_size, seq_len, d_model)
        v = self.wv(v)  # (batch_size, seq_len, d_model)
        q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
        k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
        v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)
        # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
        # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
        scaled_attention, attention_weights = self.scaled_dot_product_attention(q, k, v, mask)
        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)
        concat_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)
        output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)
        return output, attention_weights
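
A minimal smoke test for this class (not part of the original file), assuming batch-first inputs as in the TensorFlow tutorial it follows; the tensor sizes are arbitrary:

mha = MultiHeadAttention(d_model=512, num_heads=8, dropout=0.1)
q = tf.random.normal([32, 20, 512])    # (batch_size, seq_len_q, d_model)
kv = tf.random.normal([32, 10, 512])   # (batch_size, seq_len_k, d_model)
out, attn_weights = mha(q, kv, kv, None)   # no mask
print(out.shape)           # (32, 20, 512)
print(attn_weights.shape)  # (32, 8, 20, 10)

The attention weights come back as one (seq_len_q, seq_len_k) map per head, which is convenient for visualization.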

rztx.py

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import activations
import tensorflow_addons as tfa
from MultiHeadAttention import MultiHeadAttention


class RZTXEncoderLayer(keras.Model):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation='relu'):
        super(RZTXEncoderLayer, self).__init__()
        # Inputs are batch-first here: Q [N, L, E], K/V [N, S, E], with N = batch size and E = d_model
        self.self_attn = MultiHeadAttention(d_model, nhead, dropout=dropout)  # self-attention (stand-in until tf.keras ships a built-in multi-head attention layer)
        # Implementation of the feed-forward model
        self.linear1 = layers.Dense(dim_feedforward)  # first linear layer
        self.dropout = layers.Dropout(dropout)
        self.linear2 = layers.Dense(d_model)  # second linear layer
        self.dropout1 = layers.Dropout(dropout)
        self.dropout2 = layers.Dropout(dropout)
        self.resweight = tf.Variable(0.0, trainable=True)  # learnable ReZero parameter alpha, initialized to zero
        if activation == "relu":
            self.activation = activations.relu
        elif activation == "gelu":
            self.activation = tfa.activations.gelu

    def __setstate__(self, state):
        # Carried over from the PyTorch-style original; only relevant when unpickling old checkpoints.
        if 'activation' not in state:
            state['activation'] = activations.relu
        super().__setstate__(state)

    def call(self, src, mask=None):
        # Self-attention sub-layer
        src2 = src
        src2, _ = self.self_attn(src2, src2, src2, mask)  # [bs, l, emb]
        src2 = src2 * self.resweight
        src = src + self.dropout1(src2)  # [bs, l, emb]
        # Pointwise feed-forward sub-layer
        src2 = src
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src2 = src2 * self.resweight
        src = src + self.dropout2(src2)
        return src


class RZTXDecoderLayer(keras.Model):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super(RZTXDecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = MultiHeadAttention(d_model, nhead, dropout=dropout)
        # Implementation of the feed-forward model
        self.linear1 = layers.Dense(dim_feedforward)
        self.dropout = layers.Dropout(dropout)
        self.linear2 = layers.Dense(d_model)
        self.dropout1 = layers.Dropout(dropout)
        self.dropout2 = layers.Dropout(dropout)
        self.dropout3 = layers.Dropout(dropout)
        self.resweight = tf.Variable(0.0, trainable=True)
        if activation == "relu":
            self.activation = activations.relu
        elif activation == "gelu":
            self.activation = tfa.activations.gelu

    def call(self, tgt, memory, tgt_mask=None, memory_mask=None):
        tgt2, _ = self.self_attn(tgt, tgt, tgt, tgt_mask)
        tgt = tgt + self.dropout1(tgt2) * self.resweight
        # Q = tgt; K = memory; V = memory
        tgt2, _ = self.multihead_attn(tgt, memory, memory, memory_mask)
        tgt = tgt + self.dropout2(tgt2) * self.resweight
        if hasattr(self, "activation"):
            tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        else:  # for backward compatibility (the original used PyTorch's F.relu here)
            tgt2 = self.linear2(self.dropout(activations.relu(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2) * self.resweight
        return tgt


"""
encoder_layer = RZTXEncoderLayer(d_model=512, nhead=8)
src = tf.random.normal([32, 10, 512])  # [bs, q, emb]
out = encoder_layer(src)
print(out.shape)

decoder_layer = RZTXDecoderLayer(d_model=512, nhead=8)
memory = tf.random.normal([32, 10, 512])
tgt = tf.random.normal([32, 20, 512])
out = decoder_layer(tgt, memory)
print(out.shape)
"""
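
Since the point of ReZero is stable training at large depth, a natural next step is to stack many of these layers. The wrapper below is a sketch of my own (RZTXEncoder is not part of the original files), reusing the same batch-first shape conventions:

class RZTXEncoder(keras.Model):
    # Hypothetical convenience wrapper: a plain stack of ReZero encoder layers.
    def __init__(self, num_layers, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        super(RZTXEncoder, self).__init__()
        self.enc_layers = [RZTXEncoderLayer(d_model, nhead, dim_feedforward, dropout)
                           for _ in range(num_layers)]

    def call(self, src, mask=None):
        for enc_layer in self.enc_layers:
            src = enc_layer(src, mask)
        return src

encoder = RZTXEncoder(num_layers=12, d_model=512, nhead=8)
out = encoder(tf.random.normal([32, 10, 512]))  # (batch_size, seq_len, d_model)
print(out.shape)  # (32, 10, 512)

Because every resweight starts at zero, each sub-layer's contribution is zeroed out at initialization and the whole stack behaves as the identity, which is why depth does not hurt signal propagation early in training.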

 

Original article: https://blog.csdn.net/coolsunxu/article/details/105266057



