Source code for src.pipelines.tensorflow_v2.models.hyper_wnet

# W-Net structure copied from the paper:
# The Little W-Net That Could: State-of-the-Art Retinal Vessel Segmentation
# with Minimalistic Models by Galdran et al.
# Link: https://arxiv.org/abs/2009.01907
from typing import Tuple

import kerastuner
from tensorflow.keras.layers import (Conv2D, MaxPool2D, Conv2DTranspose,
                                     Input, BatchNormalization, concatenate)

from src.pipelines.tensorflow_v2.losses.custom_losses import *
from src.pipelines.tensorflow_v2.models.hyper_unet_resnet import ResBlock


# *============================ Conv Bridge Block ============================*
[docs]class ConvBridgeBlock(tf.keras.layers.Layer): """ A Convolutional Bridge Block to be used in a W-Net. This is similar to the class defined in the wnet.py script, but has the addition of the kernel_size in the instantiator to allow the option to include this parameter in the search """ def __init__(self, channels, kernel_size, initializer, activation): """ Instantiates a ConvBridgeBlock object :param channels: number of filters required for the conv layers :param kernel_size: the side length of a kernel :param activation: the activation function to use :param initializer: the weight initializer to use """ super().__init__() self.conv = Conv2D(filters=channels, kernel_size=kernel_size, strides=1, padding='same', kernel_initializer=initializer) self.bn = BatchNormalization() self.activation = activation
[docs] def call(self, x: tf.Tensor) -> tf.Tensor: """ Applies a ConvBridgeBlock to the an input :param x: the input to apply the ConvBridgeBlock to :return: the output of the ConvBridgeBlock """ x1 = self.conv(x) x1 = self.activation(x1) x1 = self.bn(x1) return x1
# *================================== W-Net ==================================*
[docs]class HyperWnet(kerastuner.HyperModel): """ A Hyper W-Net model which can be tuned by kerastuner """ def __init__(self, input_shape: Tuple[int, int, int], output_channels: int): """ Instantiates a HyperWnet object. :param input_shape: the shape of the input :param output_channels: the number of output channels to use in the final layer; this is related to the number of classes in the final prediction """ self.input_shape = input_shape self.output_channels = output_channels
[docs] def build(self, hp: kerastuner.HyperParameters) -> kerastuner.HyperModel: """ A function which creates a kerastuner hyper model with a defined search space. This function signature matches the requirements of the hyperparameter tuning algorithms in kerastuner. :param hp: a HyperParameters instance :return: a HyperWnet instance """ # Search space # create the search space initializer = hp.Choice("initializer", ["he_normal", "glorot_uniform"]) activation = hp.Choice("activation", ["relu", "selu"]) filter = hp.Choice("filters", [0, 1, 2]) kernel_size = hp.Choice("kernel_size", [3]) optimizer = hp.Choice("optimizer", ["adam", "sgd"]) wnets = hp.Choice("num_wnets", [2]) loss_choice = hp.Choice("loss", ["focal", "wce", "bce"]) if optimizer == "adam": opt = tf.keras.optimizers.Adam( hp.Float('learning_rate', 1e-4, 1e-2, sampling='log')) elif optimizer == "sgd": opt = tf.keras.optimizers.SGD( hp.Float('learning_rate', 1e-4, 1e-2, sampling='log'), momentum=hp.Float('momentum', 0.5, 0.9)) if initializer == "he_normal": initializer = tf.keras.initializers.he_normal(seed=3141) elif initializer == "glorot_uniform": initializer = tf.keras.initializers.glorot_uniform(seed=3141) if activation == "relu": activation = tf.nn.relu elif activation == "selu": activation = tf.nn.selu if loss_choice == "bce": loss = tf.keras.losses.binary_crossentropy elif loss_choice == "wce": loss = WeightedCE(hp.Float("loss_weight", 0.5, 0.99)) elif loss_choice == "focal": loss = FocalLossV2(hp.Float("loss_weight", 0.5, 0.99)) input = [] res_down1 = [] pool1 = [] res_down2 = [] pool2 = [] res_bottle = [] trans_conv1 = [] res_up1 = [] bridge1 = [] concat1 = [] trans_conv2 = [] res_up2 = [] bridge2 = [] concat2 = [] output_layer = [] input.append(Input(shape=self.input_shape)) for i in range(wnets): # Mini-Unet res_down1.append(ResBlock(8 * 2 ** filter, kernel_size=kernel_size, activation=activation, initializer=initializer, decode=True)(input[i])) pool1.append(MaxPool2D(pool_size=2, strides=2)(res_down1[i])) res_down2.append(ResBlock(16 * 2 ** filter, kernel_size=kernel_size, activation=activation, initializer=initializer, decode=True)(pool1[i])) pool2.append(MaxPool2D(pool_size=2, strides=2)(res_down2[i])) res_bottle.append(ResBlock(32 * 2 ** filter, kernel_size=kernel_size, activation=activation, initializer=initializer, decode=True)(pool2[i])) # Expanding trans_conv1.append(Conv2DTranspose(filters=16 * 2 ** filter, kernel_size=2, strides=2, padding='same')(res_bottle[i])) res_up1.append(ResBlock(16 * 2 ** filter, kernel_size=kernel_size, activation=activation, initializer=initializer)(trans_conv1[i])) bridge1.append(ConvBridgeBlock( 16 * 2 ** filter, kernel_size=kernel_size, activation=activation, initializer=initializer)(res_down2[i])) concat1.append(concatenate([res_up1[i], bridge1[i]])) trans_conv2.append(Conv2DTranspose(filters=8 * 2 ** filter, kernel_size=2, strides=2, padding='same')(concat1[i])) res_up2.append(ResBlock(8 * 2 ** filter, kernel_size=kernel_size, activation=activation, initializer=initializer)(trans_conv2[i])) bridge2.append(ConvBridgeBlock( 8 * 2 ** filter, kernel_size=kernel_size, activation=activation, initializer=initializer)(res_down1[i])) concat2.append(concatenate([res_up2[i], bridge2[i]])) # Output output_layer.append(Conv2D(self.output_channels, 1, strides=1, padding='same', activation="sigmoid")(concat2[i])) # input for next w-net input.append(concatenate([input[i], output_layer[i]])) model = tf.keras.Model(inputs=input[0], outputs=output_layer[i], name="Hyper-UNetResnet") model.compile( optimizer=opt, loss=loss, metrics=['accuracy', tf.keras.metrics.Precision(name="precision"), tf.keras.metrics.Recall(name="recall"), tf.keras.metrics.AUC(name="auc", curve="PR"), ] ) return model
# *===========================================================================*