# Source code for src.pipelines.tensorflow_v2.models.hyper_unet

from typing import Tuple

import kerastuner
import tensorflow as tf
from kerastuner import HyperModel
from tensorflow.keras.layers import (Input, Conv2D, MaxPool2D,
                                     Conv2DTranspose, concatenate)

from src.pipelines.tensorflow_v2.losses.custom_losses import *


# *========================== U-Net Building Block ===========================*
class HyperUnet(HyperModel):
    """A Hyper U-Net model which can be tuned by kerastuner.

    The encoder halves the spatial resolution four times (2x2 max pooling,
    stride 2) and the decoder doubles it four times (stride-2 transposed
    convolutions) before concatenating with the matching encoder output, so
    the spatial dimensions of the input must be divisible by 16 for the skip
    connections to line up.
    """

    # Base filter counts of the four encoder stages (before scaling by
    # 2 ** hp["filters"]). The bottleneck uses twice the last value and the
    # decoder mirrors these stages in reverse order.
    _BASE_FILTERS = (8, 16, 32, 64)

    def __init__(self, input_shape: Tuple[int, int, int], output_channels: int):
        """
        Instantiates a HyperUnet object.

        :param input_shape: the shape of the input (passed to ``Input``)
        :param output_channels: the number of output channels to use in the
            final layer; this is related to the number of classes in the
            final prediction
        """
        # Initialise the kerastuner base class so its own state (``name``,
        # ``tunable``) exists when the tuner inspects this hypermodel.
        super().__init__()
        self.input_shape = input_shape
        self.output_channels = output_channels

    @staticmethod
    def _double_conv(x, filters: int, conv_kwargs: dict):
        """Apply two consecutive Conv2D layers with ``filters`` output channels."""
        for _ in range(2):
            x = Conv2D(filters=filters, **conv_kwargs)(x)
        return x

    def build(self, hp: kerastuner.HyperParameters) -> "tf.keras.Model":
        """
        Create and compile a U-Net whose configuration is drawn from ``hp``.

        This function signature matches the requirements of the
        hyperparameter tuning algorithms in kerastuner.

        :param hp: a HyperParameters instance
        :return: a compiled ``tf.keras.Model`` named "Hyper-UNet"
        :raises ValueError: if ``hp`` yields an optimizer or loss name that
            is not part of the search space below
        """
        # Create the search space. Registration order is kept stable so the
        # tuner sees the same hyperparameter space across runs.
        initializer_name = hp.Choice("initializer", ["he_normal", "glorot_uniform"])
        activation_name = hp.Choice("activation", ["relu", "selu"])
        # Exponent applied to every base filter count: width = base * 2 ** exp.
        width_exponent = hp.Choice("filters", [0, 1, 2, 3])
        kernel_size = hp.Choice("kernel_size", [3])
        optimizer_name = hp.Choice("optimizer", ["adam", "sgd"])
        padding = "same"
        loss_name = hp.Choice("loss", ["focal", "wce", "bce"])

        if optimizer_name == "adam":
            opt = tf.keras.optimizers.Adam(
                hp.Float("learning_rate", 1e-4, 1e-2, sampling="log"))
        elif optimizer_name == "sgd":
            opt = tf.keras.optimizers.SGD(
                hp.Float("learning_rate", 1e-4, 1e-2, sampling="log"),
                momentum=hp.Float("momentum", 0.5, 0.9))
        else:
            raise ValueError(f"Unsupported optimizer choice: {optimizer_name!r}")

        # NOTE(review): one seeded initializer instance is shared by every
        # layer. On TF/Keras versions where seeded initializers are stateless,
        # layers with identical kernel shapes start from identical weights —
        # confirm this is intended.
        if initializer_name == "he_normal":
            initializer = tf.keras.initializers.he_normal(seed=3141)
        elif initializer_name == "glorot_uniform":
            initializer = tf.keras.initializers.glorot_uniform(seed=3141)
        else:
            raise ValueError(f"Unsupported initializer choice: {initializer_name!r}")

        if activation_name == "relu":
            activation = tf.nn.relu
        elif activation_name == "selu":
            activation = tf.nn.selu
        else:
            raise ValueError(f"Unsupported activation choice: {activation_name!r}")

        if loss_name == "bce":
            loss = tf.keras.losses.binary_crossentropy
        elif loss_name == "wce":
            loss = WeightedCE(hp.Float("loss_weight", 0.5, 0.99))
        elif loss_name == "focal":
            loss = FocalLossV2(hp.Float("loss_weight", 0.5, 0.99))
        else:
            raise ValueError(f"Unsupported loss choice: {loss_name!r}")

        scale = 2 ** width_exponent
        conv_kwargs = dict(kernel_size=kernel_size, padding=padding,
                           activation=activation, kernel_initializer=initializer)

        inputs = Input(shape=self.input_shape)

        # Encoder: two convolutions per stage, keep the result for the skip
        # connection, then halve the x, y dimensions.
        x = inputs
        skips = []
        for base in self._BASE_FILTERS:
            x = self._double_conv(x, base * scale, conv_kwargs)
            skips.append(x)
            x = MaxPool2D(pool_size=(2, 2), strides=2)(x)

        # Bottleneck
        x = self._double_conv(x, 2 * self._BASE_FILTERS[-1] * scale, conv_kwargs)

        # Decoder: upsample 2x, concatenate with the encoder output of the
        # same resolution, then two convolutions.
        for base, skip in zip(reversed(self._BASE_FILTERS), reversed(skips)):
            x = Conv2DTranspose(filters=base * scale, kernel_size=2, strides=2,
                                padding=padding, kernel_initializer=initializer)(x)
            x = concatenate([skip, x])
            x = self._double_conv(x, base * scale, conv_kwargs)

        # 1x1 convolution with sigmoid: one independent probability per
        # output channel.
        outputs = Conv2D(self.output_channels, 1, strides=1, padding=padding,
                         activation="sigmoid", name="classification_layer")(x)

        model = tf.keras.Model(inputs=inputs, outputs=outputs, name="Hyper-UNet")
        model.compile(
            optimizer=opt,
            loss=loss,
            metrics=['accuracy',
                     tf.keras.metrics.Precision(name="precision"),
                     tf.keras.metrics.Recall(name="recall"),
                     tf.keras.metrics.AUC(name="auc", curve="PR")]
        )
        return model