"""Source code for src.pipelines.tensorflow_v2.models.hyper_unet_resnet."""

from typing import Callable, Tuple, Union

import kerastuner
import tensorflow as tf
from kerastuner import HyperModel
from tensorflow.keras.layers import (Input, Conv2D, MaxPool2D,
                                     Conv2DTranspose, concatenate,
                                     BatchNormalization)

from src.pipelines.tensorflow_v2.losses.custom_losses import *


# *============================= Residual Block ==============================*
class ResBlock(tf.keras.layers.Layer):
    """
    A Residual Block.

    This is similar to the class defined in the unet_resnet.py script, but
    has the addition of the kernel_size in the instantiator to allow the
    option to include this parameter in the search space.

    Computation (see ``call``)::

        main     = act(conv2(bn1(act(conv1(x)))))
        shortcut = act(conv3(x))  if stride != 1 or decode  else  x
        output   = bn2(main + shortcut)
    """

    def __init__(self, channels: int, kernel_size: int,
                 activation: Union[str, Callable[[tf.Tensor], tf.Tensor]],
                 initializer: Union[str, tf.keras.initializers.Initializer],
                 stride: int = 1, decode: bool = False, **kwargs):
        """
        Instantiates a ResBlock object.

        :param channels: number of filters required for the conv layers
        :param kernel_size: the side length of a kernel
        :param activation: the activation function to use; either a callable
            (e.g. ``tf.nn.relu``) or the name of a Keras activation
            (e.g. ``"relu"``)
        :param initializer: the kernel initializer to use; an initializer
            instance or a Keras initializer name
        :param stride: the stride of the first conv layer (and of the 1x1
            projection shortcut, when one is used)
        :param decode: whether this ResBlock is being used in expansive path
            of the U-Net (its input is then a concatenated feature map whose
            channel count differs from ``channels``)
        :param kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``)
        """
        super().__init__(**kwargs)
        # Names such as "relu" are resolved to a function; callables are
        # returned unchanged, so passing e.g. tf.nn.relu behaves as before.
        self.activation = tf.keras.activations.get(activation)
        # True when the block downsamples (dotted shortcut, i.e. stride = 2)
        self.flag = (stride != 1)
        # True when the input is a concatenated feature map
        self.decode = decode

        # main (residual) path
        self.conv1 = Conv2D(filters=channels, kernel_size=kernel_size,
                            strides=stride, padding='same',
                            kernel_initializer=initializer)
        self.bn1 = BatchNormalization()
        self.conv2 = Conv2D(filters=channels, kernel_size=kernel_size,
                            strides=1, padding='same',
                            kernel_initializer=initializer)
        self.bn2 = BatchNormalization()

        if self.decode or self.flag:
            # 1x1 projection so the shortcut matches the main path's spatial
            # size (stride) and channel count (concatenated decoder input).
            self.conv3 = Conv2D(channels, 1, stride,
                                kernel_initializer=initializer)

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """
        Applies the ResBlock to an input.

        :param x: the input to apply the ResBlock to
        :return: the output of the ResBlock
        """
        main = self.conv1(x)
        main = self.activation(main)
        main = self.bn1(main)
        main = self.conv2(main)
        main = self.activation(main)

        # Matching dims (i.e. projection shortcut)
        shortcut = x
        if self.flag or self.decode:
            shortcut = self.activation(self.conv3(x))

        # Batch normalisation is applied after the residual addition.
        return self.bn2(main + shortcut)
class HyperUnetResnet(HyperModel):
    """
    A Hyper U-Net ResNet34 model which can be tuned by kerastuner.

    The encoder uses ResNet34 stage depths (3, 4, 6, 3 ResBlocks). The decoder
    mirrors it with transposed convolutions and skip connections.
    """

    def __init__(self, input_shape: Tuple[int, int, int],
                 output_channels: int, **kwargs):
        """
        Instantiates a HyperUnetResnet object.

        :param input_shape: the shape of the input
        :param output_channels: the number of output channels to use in the
            final layer; this is related to the number of classes in the final
            prediction
        :param kwargs: forwarded to ``kerastuner.HyperModel``
            (e.g. ``name``, ``tunable``)
        """
        # Initialise the base class so attributes it defines (e.g. name,
        # tunable) exist on this instance.
        super().__init__(**kwargs)
        self.input_shape = input_shape
        self.output_channels = output_channels

    @staticmethod
    def _make_optimizer(hp: kerastuner.HyperParameters, name: str):
        """Build the optimizer chosen by the "optimizer" hyperparameter."""
        if name == "adam":
            return tf.keras.optimizers.Adam(
                hp.Float('learning_rate', 1e-4, 1e-2, sampling='log'))
        if name == "sgd":
            return tf.keras.optimizers.SGD(
                hp.Float('learning_rate', 1e-4, 1e-2, sampling='log'),
                momentum=hp.Float('momentum', 0.5, 0.9))
        raise ValueError(f"Unsupported optimizer: {name!r}")

    @staticmethod
    def _make_initializer(name: str):
        """Build the (seeded) kernel initializer chosen by its name."""
        if name == "he_normal":
            return tf.keras.initializers.he_normal(seed=3141)
        if name == "glorot_uniform":
            return tf.keras.initializers.glorot_uniform(seed=3141)
        raise ValueError(f"Unsupported initializer: {name!r}")

    @staticmethod
    def _make_activation(name: str):
        """Map an activation name to its tf.nn function."""
        if name == "relu":
            return tf.nn.relu
        if name == "selu":
            return tf.nn.selu
        raise ValueError(f"Unsupported activation: {name!r}")

    @staticmethod
    def _make_loss(hp: kerastuner.HyperParameters, name: str):
        """Build the loss chosen by the "loss" hyperparameter."""
        if name == "bce":
            return tf.keras.losses.binary_crossentropy
        if name == "wce":
            return WeightedCE(hp.Float("loss_weight", 0.5, 0.99))
        if name == "focal":
            return FocalLossV2(hp.Float("loss_weight", 0.5, 0.99))
        raise ValueError(f"Unsupported loss: {name!r}")

    @staticmethod
    def _res_stage(x, channels: int, n_blocks: int, kernel_size: int,
                   activation, initializer, stride: int = 1,
                   decode: bool = False):
        """
        Apply ``n_blocks`` consecutive ResBlocks with ``channels`` filters.

        Only the first block receives ``stride`` and ``decode``; the remaining
        blocks use stride 1 and no decode projection.
        """
        for i in range(n_blocks):
            is_first = i == 0
            x = ResBlock(channels=channels, kernel_size=kernel_size,
                         activation=activation, initializer=initializer,
                         stride=stride if is_first else 1,
                         decode=decode and is_first)(x)
        return x

    def build(self, hp: kerastuner.HyperParameters) -> tf.keras.Model:
        """
        Create a compiled Keras model from a point in the search space.

        This function signature matches the requirements of the
        hyperparameter tuning algorithms in kerastuner.

        :param hp: a HyperParameters instance
        :return: a compiled ``tf.keras.Model``
        """
        # ---------------------------- search space ----------------------------
        # Hyperparameters are registered in the same order as before
        # (choices first, then the conditional learning_rate / momentum /
        # loss_weight).
        initializer_name = hp.Choice("initializer",
                                     ["he_normal", "glorot_uniform"])
        activation_name = hp.Choice("activation", ["relu", "selu"])
        width_exp = hp.Choice("filters", [0, 1, 2, 3])
        kernel_size = hp.Choice("kernel_size", [3])
        optimizer_name = hp.Choice("optimizer", ["adam", "sgd"])
        loss_name = hp.Choice("loss", ["focal", "wce", "bce"])

        opt = self._make_optimizer(hp, optimizer_name)
        initializer = self._make_initializer(initializer_name)
        activation = self._make_activation(activation_name)
        loss = self._make_loss(hp, loss_name)

        # Number of filters in the shallowest stage; deeper stages use 2x, 4x
        # and 8x this value.
        base = 8 * 2 ** width_exp
        block_kwargs = dict(kernel_size=kernel_size, activation=activation,
                            initializer=initializer)

        inputs = Input(shape=self.input_shape)

        # ---------------------------- contracting ----------------------------
        # Stem: 7x7 stride-2 conv with 'same' padding, so the output is
        # ceil(H/2) x ceil(W/2).
        # NOTE(review): unlike the other convs, the stem and output convs do
        # not use the tuned `initializer` (Keras default glorot_uniform
        # applies). Confirm whether that is intended.
        down_conv1 = Conv2D(filters=base, kernel_size=7, strides=2,
                            padding='same')(inputs)
        down1_activated = activation(down_conv1)
        down1_bn = BatchNormalization()(down1_activated)
        mp1 = MaxPool2D(pool_size=2, strides=2)(down1_bn)

        # Total downsampling is 32x (stem 2x, pool 2x, three stride-2 stages).
        # H and W should be divisible by 32 so the skip connections line up.
        down2 = self._res_stage(mp1, base, 3, **block_kwargs)
        down3 = self._res_stage(down2, 2 * base, 4, stride=2, **block_kwargs)
        down4 = self._res_stage(down3, 4 * base, 6, stride=2, **block_kwargs)
        down5 = self._res_stage(down4, 8 * base, 3, stride=2, **block_kwargs)

        # ----------------------------- expanding -----------------------------
        # concatenate's default axis is -1, i.e. the channel axis.
        # Layer 1
        up1 = Conv2DTranspose(filters=4 * base, kernel_size=2, strides=2,
                              kernel_initializer=initializer)(down5)
        up1 = concatenate([up1, down4])
        up1 = self._res_stage(up1, 4 * base, 6, decode=True, **block_kwargs)

        # Layer 2
        up2 = Conv2DTranspose(filters=2 * base, kernel_size=2, strides=2,
                              kernel_initializer=initializer)(up1)
        up2 = concatenate([up2, down3])
        up2 = self._res_stage(up2, 2 * base, 4, decode=True, **block_kwargs)

        # Layer 3
        up3 = Conv2DTranspose(filters=base, kernel_size=2, strides=2,
                              kernel_initializer=initializer)(up2)
        up3 = concatenate([up3, down2])
        up3 = self._res_stage(up3, base, 3, decode=True, **block_kwargs)

        # Layer 4: skip connection from the (pre-BN) stem activation
        conv_up4 = Conv2DTranspose(filters=base, kernel_size=2, strides=2,
                                   kernel_initializer=initializer)(up3)
        conv_up_concat_4 = concatenate([conv_up4, down1_activated])
        # need same padding so that output image size is the same as mask
        up_conv4 = Conv2D(filters=base, kernel_size=7, strides=1,
                          kernel_initializer=initializer,
                          padding="same")(conv_up_concat_4)
        up4_activated = activation(up_conv4)
        up4_bn = BatchNormalization()(up4_activated)

        # Undo the stem's stride-2 downsampling.
        # Think about whether this needs to be activated
        conv_up5 = Conv2DTranspose(filters=base, kernel_size=2, strides=2,
                                   kernel_initializer=initializer)(up4_bn)
        output_layer = Conv2D(self.output_channels, 1, strides=1,
                              padding='same',
                              activation="sigmoid")(conv_up5)

        model = tf.keras.Model(inputs=inputs, outputs=output_layer,
                               name="Hyper-UNetResnet")
        model.compile(
            optimizer=opt,
            loss=loss,
            metrics=['accuracy',
                     tf.keras.metrics.Precision(name="precision"),
                     tf.keras.metrics.Recall(name="recall"),
                     tf.keras.metrics.AUC(name="auc", curve="PR")]
        )
        return model