# Source code for src.pipelines.tensorflow_v2.models.unet_resnet

# Model structure adapted from: Road Extraction by Deep Residual U-Net by
# Zhang, Liu, and Wang
# Link: https://arxiv.org/abs/1711.10684
from typing import List, Tuple

import tensorflow as tf
from tensorflow.keras.layers import (Conv2D, MaxPool2D, Concatenate,
                                     Conv2DTranspose, BatchNormalization,
                                     Input)

from src.pipelines.tensorflow_v2.helpers.train_test import _TfPnsMixin


# *============================= Residual Block ==============================*
class ResBlock(tf.keras.Model):
    """
    A residual block: two 3x3 convolutions with a shortcut connection.

    The shortcut is the identity unless the block strides (``stride != 1``)
    or is used in the decoder (``decode=True``); in those cases the input is
    passed through a 1x1 convolution so that it can be added to the output
    of the main path.
    """

    def __init__(self, channels, initializer, activation, stride=1,
                 decode=False):
        """
        Builds the layers of the residual block.

        :param channels: number of filters used by every conv layer
        :param initializer: the kernel initializer passed to the conv layers
        :param activation: callable applied after each convolution
        :param stride: stride of the first 3x3 convolution and of the 1x1
            shortcut convolution
        :param decode: whether the block sits in the expansive path of the
            U-Net, i.e. receives a concatenated feature map
        """
        super().__init__()
        self.activation = activation
        # True when the block strides, so the shortcut must be projected
        self.flag = stride != 1
        # True when the input is a concatenated feature map
        self.decode = decode

        # Main path
        self.conv1 = Conv2D(
            filters=channels,
            kernel_size=3,
            strides=stride,
            padding='same',
            kernel_initializer=initializer,
        )
        self.bn1 = BatchNormalization()
        self.conv2 = Conv2D(
            filters=channels,
            kernel_size=3,
            strides=1,
            padding='same',
            kernel_initializer=initializer,
        )
        self.bn2 = BatchNormalization()

        # Projection shortcut, only created when it will be used
        if self.flag or self.decode:
            self.conv3 = Conv2D(
                filters=channels,
                kernel_size=1,
                strides=stride,
                kernel_initializer=initializer,
            )

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """
        Runs the residual block on an input.

        :param x: the input to apply the ResBlock to
        :return: the output of the ResBlock
        """
        # The activation is applied before batch norm, which differs from
        # the ordering used in the batch normalisation paper.
        out = self.activation(self.conv1(x))
        out = self.bn1(out)
        out = self.activation(self.conv2(out))

        shortcut = x
        if self.flag or self.decode:
            # Projection shortcut to match dimensions. Note that the
            # projected shortcut is passed through the activation as well.
            shortcut = self.activation(self.conv3(x))

        out = out + shortcut
        return self.bn2(out)
# *============================== ResNet U-Net ===============================*
class UnetResnet(tf.keras.Model, _TfPnsMixin):
    """
    A U-Net model, which has a ResNet 34 backbone.

    The contracting path halves the spatial resolution five times (strided
    7x7 convolution, max pooling and three strided residual stages); the
    expansive path doubles it five times with transposed convolutions.
    NOTE(review): this implies the input height and width should be
    divisible by 32 for the skip connections to line up - confirm.
    """

    def __init__(self, output_channels: int, activation: str = "relu",
                 initializer: str = "he_normal", filters: int = 3):
        """
        Instantiates a U-Net ResNet 34 model.

        :param output_channels: the number of output channels to use in the
            final layer; this is related to the number of classes in the
            final prediction
        :param activation: the activation to use in all blocks apart from
            the last; ``"relu"`` selects ``tf.nn.relu``, any other value
            selects ``tf.nn.selu``
        :param initializer: the initialiser to use in all blocks apart from
            the last; ``"he_normal"`` selects He normal, any other value
            selects Glorot uniform (both seeded with 3141)
        :param filters: the filter multiple to use; layer widths are
            8, 16, 32 and 64 times ``2 ** filters``
        """
        super().__init__()

        if initializer == "he_normal":
            kernel_init = tf.keras.initializers.he_normal(seed=3141)
        else:
            kernel_init = tf.keras.initializers.glorot_uniform(seed=3141)

        if activation == "relu":
            act_fn = tf.nn.relu
        else:
            act_fn = tf.nn.selu
        self.activation = act_fn

        # Channel widths of the four resolution stages
        scale = 2 ** filters
        width_1 = 8 * scale
        width_2 = 16 * scale
        width_3 = 32 * scale
        width_4 = 64 * scale

        def res_block(channels, **kwargs):
            # Residual block sharing the model-wide activation/initializer
            return ResBlock(channels, activation=act_fn,
                            initializer=kernel_init, **kwargs)

        def up_conv(channels):
            # 2x2 transposed convolution that doubles the resolution; the
            # activation is applied by the layer itself
            return Conv2DTranspose(filters=channels, kernel_size=2,
                                   strides=2, activation=act_fn,
                                   kernel_initializer=kernel_init)

        # ---- Contracting path ----
        # 'same' padding; the down sampling comes from the stride of 2
        self.down_conv1 = Conv2D(filters=width_1, kernel_size=7, strides=2,
                                 padding='same',
                                 kernel_initializer=kernel_init)
        self.down_bn = BatchNormalization()
        # Pool size and stride are both 2
        self.mp1 = MaxPool2D(pool_size=2, strides=2)

        self.down_block_2_1 = res_block(width_1)
        self.down_block_2_2 = res_block(width_1)
        self.down_block_2_3 = res_block(width_1)

        self.down_block_3_1 = res_block(width_2, stride=2)
        self.down_block_3_2 = res_block(width_2)
        self.down_block_3_3 = res_block(width_2)
        self.down_block_3_4 = res_block(width_2)

        self.down_block_4_1 = res_block(width_3, stride=2)
        self.down_block_4_2 = res_block(width_3)
        self.down_block_4_3 = res_block(width_3)
        self.down_block_4_4 = res_block(width_3)
        self.down_block_4_5 = res_block(width_3)
        self.down_block_4_6 = res_block(width_3)

        self.down_block_5_1 = res_block(width_4, stride=2)
        self.down_block_5_2 = res_block(width_4)
        self.down_block_5_3 = res_block(width_4)

        # ---- Expansive path ----
        # Layer 1
        self.conv_up1 = up_conv(width_3)
        # default axis is -1 => the filter axis
        self.conv_up_concat_1 = Concatenate()
        self.up_block_1_1 = res_block(width_3, decode=True)
        self.up_block_1_2 = res_block(width_3)
        self.up_block_1_3 = res_block(width_3)
        self.up_block_1_4 = res_block(width_3)
        self.up_block_1_5 = res_block(width_3)
        self.up_block_1_6 = res_block(width_3)

        # Layer 2
        self.conv_up2 = up_conv(width_2)
        self.conv_up_concat_2 = Concatenate()
        self.up_block_2_1 = res_block(width_2, decode=True)
        self.up_block_2_2 = res_block(width_2)
        self.up_block_2_3 = res_block(width_2)
        self.up_block_2_4 = res_block(width_2)

        # Layer 3
        self.conv_up3 = up_conv(width_1)
        self.conv_up_concat_3 = Concatenate()
        self.up_block_3_1 = res_block(width_1, decode=True)
        self.up_block_3_2 = res_block(width_1)
        self.up_block_3_3 = res_block(width_1)

        # Layer 4: mirrors the first 7x7 convolution of the contracting
        # path; its activation is applied explicitly in call()
        self.conv_up4 = up_conv(width_1)
        self.conv_up_concat_4 = Concatenate()
        self.up_conv4 = Conv2D(filters=width_1, kernel_size=7, strides=1,
                               padding='same',
                               kernel_initializer=kernel_init)
        self.up_bn = BatchNormalization()

        # Final up sampling back to the input resolution
        self.conv_up5 = up_conv(width_1)

        # 1x1 convolution with a sigmoid, using the default initializer
        self.output_layer = Conv2D(output_channels, 1, strides=1,
                                   padding='same', activation="sigmoid")

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """
        Applies a U-Net ResNet34 model to the input.

        :param x: an input image
        :return: the output of a U-Net ResNet34 model
        """
        # Contracting
        x = self.down_conv1(x)
        x = self.activation(x)
        down_conv1 = self.down_bn(x)
        x = self.mp1(down_conv1)

        x = self.down_block_2_1(x)
        x = self.down_block_2_2(x)
        down_2_3 = self.down_block_2_3(x)

        x = self.down_block_3_1(down_2_3)
        x = self.down_block_3_2(x)
        x = self.down_block_3_3(x)
        down_3_4 = self.down_block_3_4(x)

        x = self.down_block_4_1(down_3_4)
        x = self.down_block_4_2(x)
        x = self.down_block_4_3(x)
        x = self.down_block_4_4(x)
        x = self.down_block_4_5(x)
        down_4_6 = self.down_block_4_6(x)

        x = self.down_block_5_1(down_4_6)
        x = self.down_block_5_2(x)
        x = self.down_block_5_3(x)

        # Expanding: each stage up samples, concatenates the matching
        # skip connection and applies its residual blocks
        x = self.conv_up1(x)
        x = self.conv_up_concat_1([x, down_4_6])
        x = self.up_block_1_1(x)
        x = self.up_block_1_2(x)
        x = self.up_block_1_3(x)
        x = self.up_block_1_4(x)
        x = self.up_block_1_5(x)
        x = self.up_block_1_6(x)

        x = self.conv_up2(x)
        x = self.conv_up_concat_2([x, down_3_4])
        x = self.up_block_2_1(x)
        x = self.up_block_2_2(x)
        x = self.up_block_2_3(x)
        x = self.up_block_2_4(x)

        x = self.conv_up3(x)
        x = self.conv_up_concat_3([x, down_2_3])
        x = self.up_block_3_1(x)
        x = self.up_block_3_2(x)
        x = self.up_block_3_3(x)

        x = self.conv_up4(x)
        x = self.conv_up_concat_4([x, down_conv1])
        x = self.up_conv4(x)
        x = self.activation(x)
        x = self.up_bn(x)

        # Undo the initial stride-2 convolution, then classify per pixel
        x = self.conv_up5(x)
        x = self.output_layer(x)
        return x

    def model(self, shape: Tuple[int, int, int] = (512, 512, 1)
              ) -> tf.keras.Model:
        """
        Returns a U-Net model as tf.keras.Model. This is a workaround to use
        the functional api, which allows the model to be viewed.

        :param shape: the shape of the input
        :return: the tf.keras.Model instantiated using the functional api
        """
        x = Input(shape=shape)
        return tf.keras.Model(inputs=[x], outputs=self.call(x))

    @staticmethod
    def _expand_layer(layer) -> list:
        """
        Returns the sub-layers of a layer, or the layer itself if it is a
        leaf.

        Only nested models (e.g. ResBlock) expose ``.layers``; plain layers
        such as Conv2D do not, so they are returned as a one-element list
        instead of raising an AttributeError.

        :param layer: a layer taken from the functional model
        :return: a list with the layer's sub-layers, or ``[layer]``
        """
        sub_layers = getattr(layer, "layers", None)
        return list(sub_layers) if sub_layers else [layer]

    def print_all_layers(self) -> None:
        """
        Prints all the layers in the model, including the layers in the
        subclasses which make up the model. This uses the model() workaround
        function.

        :return: None
        """
        model_layers = self.model().layers
        print(model_layers[0])
        for layer in model_layers[1:-1]:
            for sub_layer in self._expand_layer(layer):
                print(sub_layer)
        print(model_layers[-1])

    def get_all_layers(self) -> List[tf.keras.layers.Layer]:
        """
        Returns all the layers in the model, including the layers in the
        subclasses which make up the model. This uses the model() workaround
        function. The input layer (first layer of the functional model) is
        not included.

        :return: a list of layers
        """
        model_layers = self.model().layers
        layers = [
            sub_layer
            for layer in model_layers[1:-1]
            for sub_layer in self._expand_layer(layer)
        ]
        layers.append(model_layers[-1])
        return layers
# *===========================================================================*