52matlab技术网站,matlab教程,matlab安装教程,matlab下载

标题: CBAM通道注意力的TensorFlow 2.x实现 [打印本页]

作者: matlab的旋律    时间: 2024-9-30 04:23
标题: CBAM通道注意力的TensorFlow 2.x实现
CBAM通道注意力的TensorFlow 2.x实现，参考：https://blog.csdn.net/weixin_39122088/article/details/10719197
from keras import layers, regularizers
  # Implementation 1: stack the average- and max-pooled descriptors and run
  # them through the shared MLP in a single pass.
class ChannelAttention(layers.Layer):
    """CBAM channel-attention map (implementation 1).

    Global average-pooled and global max-pooled channel descriptors are fed
    through a shared two-layer MLP (``in_planes // ratio`` hidden units with
    ReLU, then ``in_planes`` outputs), summed, and passed through a sigmoid.

    Args:
        in_planes: Width of the output attention vector (the MLP's output
            size); expected to equal the channel count of the input.
        ratio: Channel-reduction ratio of the hidden MLP layer.
        **kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``, ``dtype``).

    Calling the layer returns a tensor of shape ``(batch, 1, 1, in_planes)``.
    Assumes a 4-D channels-last input (the Keras default data format) --
    confirm against the caller. Requires ``tf`` (TensorFlow) to be imported
    at module level; this snippet does not import it.
    """

    def __init__(self, in_planes, ratio=8, **kwargs):
        super().__init__(**kwargs)
        self.in_planes = in_planes
        self.ratio = ratio

        self.avg_out = layers.GlobalAveragePooling2D()
        self.max_out = layers.GlobalMaxPooling2D()

        # Guard against in_planes < ratio, which would otherwise request a
        # zero-unit hidden layer.
        hidden = max(1, in_planes // ratio)
        self.fc1 = layers.Dense(hidden, kernel_initializer='he_normal',
                                kernel_regularizer=regularizers.l2(5e-4),
                                activation=tf.nn.relu,
                                use_bias=True, bias_initializer='zeros')
        self.fc2 = layers.Dense(in_planes, kernel_initializer='he_normal',
                                kernel_regularizer=regularizers.l2(5e-4),
                                use_bias=True, bias_initializer='zeros')

    def call(self, inputs):
        avg_out = self.avg_out(inputs)  # (batch, C)
        max_out = self.max_out(inputs)  # (batch, C)
        # Stacking lets both descriptors share one pass through the MLP;
        # Dense only acts on the last axis.
        out = tf.stack([avg_out, max_out], axis=1)  # (batch, 2, C)
        out = self.fc2(self.fc1(out))               # (batch, 2, in_planes)
        out = tf.reduce_sum(out, axis=1)            # (batch, in_planes)
        out = tf.nn.sigmoid(out)
        # Reshape with a plain op (batch as -1) instead of constructing a new
        # layers.Reshape object on every call.
        return tf.reshape(out, [-1, 1, 1, out.shape[-1]])

    def get_config(self):
        # Make the layer serializable: from_config passes these back into
        # __init__ (base-layer keys travel through **kwargs).
        config = super().get_config()
        config.update({'in_planes': self.in_planes, 'ratio': self.ratio})
        return config
复制代码

  # Implementation 2: run the average- and max-pooled descriptors through the
  # shared MLP separately, then tile the attention map to the input's H x W.
class ChannelAttention(layers.Layer):
    """CBAM channel-attention map (implementation 2).

    Global average-pooled and global max-pooled channel descriptors are each
    passed through a shared two-layer MLP (``in_planes // ratio`` hidden units
    with ReLU, then ``in_planes`` outputs). The two results are summed and
    passed through a sigmoid. The resulting map is then tiled to the input's
    spatial size.

    Args:
        in_planes: Width of the output attention vector (the MLP's output
            size); expected to equal the channel count of the input.
        ratio: Channel-reduction ratio of the hidden MLP layer. The default of
            16 matches the previously hard-coded value.
        **kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``, ``dtype``).

    Calling the layer returns a tensor of shape ``(batch, H, W, in_planes)``,
    where H and W are taken from the input. Assumes a 4-D channels-last input
    (the Keras default data format) -- confirm against the caller. Requires
    ``tf`` (TensorFlow) to be imported at module level; this snippet does not
    import it.
    """

    def __init__(self, in_planes, ratio=16, **kwargs):
        super().__init__(**kwargs)
        self.in_planes = in_planes
        self.ratio = ratio

        # Attribute names are kept unchanged because TensorFlow object-based
        # checkpoints key tracked sublayers by attribute name.
        self.avg = layers.GlobalAveragePooling2D()
        self.max = layers.GlobalMaxPooling2D()

        # Guard against in_planes < ratio (would request zero hidden units).
        hidden = max(1, in_planes // ratio)
        self.fc1 = layers.Dense(hidden, kernel_initializer='he_normal',
                                activation='relu',
                                use_bias=True, bias_initializer='zeros')
        self.fc2 = layers.Dense(in_planes, kernel_initializer='he_normal',
                                use_bias=True, bias_initializer='zeros')

    def call(self, inputs):
        avg_out = self.fc2(self.fc1(self.avg(inputs)))
        max_out = self.fc2(self.fc1(self.max(inputs)))
        out = tf.nn.sigmoid(avg_out + max_out)  # (batch, in_planes)
        # Use -1 and tf.shape instead of static .shape entries. The batch
        # dimension (and any dynamic spatial dimension) is None while a
        # graph is being built, and tf.reshape / tf.tile cannot accept None.
        out = tf.reshape(out, [-1, 1, 1, out.shape[-1]])
        in_shape = tf.shape(inputs)
        return tf.tile(out, [1, in_shape[1], in_shape[2], 1])

    def get_config(self):
        # Make the layer serializable: from_config passes these back into
        # __init__ (base-layer keys travel through **kwargs).
        config = super().get_config()
        config.update({'in_planes': self.in_planes, 'ratio': self.ratio})
        return config
复制代码






欢迎光临 52matlab技术网站,matlab教程,matlab安装教程,matlab下载 (http://test.52matlab.com/) Powered by Discuz! X3.2