from typing import List, Tuple
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, MaxPool2D, Conv2DTranspose
from src.pipelines.tensorflow_v2.helpers.train_test import _TfPnsMixin
# *========================== U-Net Building Block ===========================*
class UnetBlock(tf.keras.Model):
    """
    A building block to be used when constructing a U-Net model.

    Every block applies two convolutions (each optionally followed by batch
    normalization). The ``decode`` / ``encode`` flags select what surrounds
    that pair of convolutions.

    NOTE: the flag names are the reverse of what one might expect; the
    mapping below describes what the code actually does and what ``Unet``
    relies on:

    1) ``decode=True``  | contracting path: convolutions followed by a
       2x2 max-pool (see ``downblock``)
    2) ``encode=True``  | expanding path: transposed convolution and skip
       concatenation followed by the convolutions (see ``upblock``)
    3) both ``False``   | bottleneck: the convolutions only
    """

    def __init__(self, num_filters: int, kernel_size: int,
                 decode: bool = False, encode: bool = False,
                 batch_norm: bool = False, padding: str = "same",
                 activation="relu", initializer='he_normal',
                 name: str = ""):
        """
        Creates the layers required by the selected mode.

        :param num_filters: number of filters required for the conv layers
        :param kernel_size: h,w of the kernel
        :param decode: if True the block pools its output, i.e. it behaves
            as a contracting-path block
        :param encode: if True the block upsamples its input and
            concatenates a skip connection, i.e. it behaves as an
            expanding-path block
        :param batch_norm: whether to include batch normalization layers
        :param padding: how to pad the input
        :param activation: the activation function to use
        :param initializer: the weight initializer to use
        :param name: the name of the block
        :raises ValueError: if both ``decode`` and ``encode`` are True
        """
        # Validate before building anything. Previously this check sat at
        # the end of an if/elif chain where it could never be reached, so
        # passing both flags silently produced a pooling block.
        if decode and encode:
            raise ValueError("Decode and Encode can't both be True")

        super().__init__(name=name)

        self.conv1 = Conv2D(filters=num_filters, kernel_size=kernel_size,
                            padding=padding, activation=activation,
                            kernel_initializer=initializer)
        self.conv2 = Conv2D(filters=num_filters, kernel_size=kernel_size,
                            padding=padding, activation=activation,
                            kernel_initializer=initializer)

        if batch_norm:
            self.bn1 = tf.keras.layers.BatchNormalization()
            self.bn2 = tf.keras.layers.BatchNormalization()

        if decode:
            # half the x,y dimensions each pool step
            self.pool = MaxPool2D(pool_size=(2, 2), strides=2)
        elif encode:
            # default axis is -1 => the filter axis
            self.concat = tf.keras.layers.Concatenate()
            # double the x,y dimensions each upsampling step; the block's
            # activation is applied to the transposed convolution as well
            self.conv_trans = Conv2DTranspose(filters=num_filters,
                                              kernel_size=2, strides=2,
                                              padding=padding,
                                              activation=activation,
                                              kernel_initializer=initializer)

        # Store user-defined input to be used in call step
        self.batch_norm = batch_norm
        self.decode = decode
        self.encode = encode

    def _double_conv(self, input_tensor: tf.Tensor) -> tf.Tensor:
        """
        Applies conv1 -> (bn1) -> conv2 -> (bn2); shared by all three modes.

        :param input_tensor: the tensor to convolve
        :return: output after both convolutions
        """
        x = self.conv1(input_tensor)
        if self.batch_norm:
            x = self.bn1(x)
        x = self.conv2(x)
        if self.batch_norm:
            x = self.bn2(x)
        return x

    def downblock(self,
                  input_tensor: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        A block to be used in the contracting path of a U-Net.

        Requires the block to have been created with ``decode=True``, since
        that is the only mode which creates the pooling layer.

        :param input_tensor: the input to the downblock
        :return: output before pooling, and the pooled output
        """
        x = self._double_conv(input_tensor)
        pool_out = self.pool(x)
        return x, pool_out

    def upblock(self,
                input_tensor: tf.Tensor,
                downblock_output: tf.Tensor) -> tf.Tensor:
        """
        A block to be used in the expansive path of a U-Net. It requires as
        input the output from the previous upblock and the output prior to
        pooling from the corresponding downblock.

        Requires the block to have been created with ``encode=True``, since
        that is the only mode which creates the upsampling and
        concatenation layers.

        :param input_tensor: the output from the previous upblock
        :param downblock_output: the output prior to pooling from the
            corresponding downblock.
        :return: output after upsampling and convolutions
        """
        x = self.conv_trans(input_tensor)
        # skip connection: upsampled tensor first, then the downblock output
        x = self.concat([x, downblock_output])
        return self._double_conv(x)

    def bottleneck(self, input_tensor: tf.Tensor) -> tf.Tensor:
        """
        The bottleneck block used in a U-Net

        :param input_tensor: the output of the final downblock in the
            contracting path
        :return: output after convolutions
        """
        return self._double_conv(input_tensor)

    def call(self,
             input_tensor: tf.Tensor,
             downblock_output: tf.Tensor = None) -> tf.Tensor:
        """
        Calls either a downblock, bottleneck block, or upblock depending on
        the states of the decode and encode flags.

        :param input_tensor: the input to be given to the relevant block
        :param downblock_output: the output prior to pooling from the
            corresponding downblock; only required if an upblock is being
            used
        :return: a tensor output of the relevant block; for a downblock
            (``decode=True``) this is a tuple of the output before pooling
            and the pooled output
        :raises ValueError: if an upblock is used without
            ``downblock_output``
        """
        if self.decode:
            return self.downblock(input_tensor)
        if self.encode:
            if downblock_output is None:
                raise ValueError("Please provide output to use for the skip "
                                 "connection")
            return self.upblock(input_tensor, downblock_output)
        return self.bottleneck(input_tensor)
# *================================== U-Net ==================================*
class Unet(tf.keras.Model, _TfPnsMixin):
    """
    A U-Net model

    The network consists of four contracting blocks, a bottleneck, four
    expanding blocks with skip connections, and a final 1x1 convolution with
    a sigmoid activation.
    """

    def __init__(self,
                 output_channels: int,
                 activation: str = "relu",
                 initializer: str = "he_normal",
                 filters: int = 3):
        """
        Instantiates a U-Net

        :param output_channels: the number of output channels to use in the
            final layer; this is related to the number of classes in the
            final prediction
        :param activation: the activation to use in all blocks apart from
            the last; ``"relu"`` selects ``tf.nn.relu``, any other value
            falls back to ``tf.nn.selu``
        :param initializer: the initialiser to use in all blocks apart from
            the last; ``"he_normal"`` selects He normal, any other value
            falls back to Glorot uniform (both seeded with 3141)
        :param filters: the filter multiple to use; the first level has
            ``8 * 2 ** filters`` filters and the count doubles at each
            deeper level up to ``128 * 2 ** filters`` in the bottleneck
        """
        super().__init__()

        if initializer == "he_normal":
            initializer = tf.keras.initializers.he_normal(seed=3141)
        else:
            initializer = tf.keras.initializers.glorot_uniform(seed=3141)

        if activation == "relu":
            activation = tf.nn.relu
        else:
            activation = tf.nn.selu
        self.activation = activation

        scale = 2 ** filters
        # Arguments shared by every block in the network
        shared = dict(kernel_size=3, activation=activation,
                      initializer=initializer)

        # Contracting (UnetBlock uses decode=True for its pooling mode)
        self.conv_down1 = UnetBlock(num_filters=8 * scale, decode=True,
                                    name="down1", **shared)
        self.conv_down2 = UnetBlock(num_filters=16 * scale, decode=True,
                                    name="down2", **shared)
        self.conv_down3 = UnetBlock(num_filters=32 * scale, decode=True,
                                    name="down3", **shared)
        self.conv_down4 = UnetBlock(num_filters=64 * scale, decode=True,
                                    name="down4", **shared)

        # Bottleneck
        self.conv_bottle = UnetBlock(num_filters=128 * scale,
                                     name="bottleneck", **shared)

        # Expanding (UnetBlock uses encode=True for its upsampling mode)
        self.conv_up1 = UnetBlock(num_filters=64 * scale, encode=True,
                                  name="up1", **shared)
        self.conv_up2 = UnetBlock(num_filters=32 * scale, encode=True,
                                  name="up2", **shared)
        self.conv_up3 = UnetBlock(num_filters=16 * scale, encode=True,
                                  name="up3", **shared)
        self.conv_up4 = UnetBlock(num_filters=8 * scale, encode=True,
                                  name="up4", **shared)

        # Output
        self.output_layer = Conv2D(output_channels, 1, strides=1,
                                   padding='same',
                                   activation="sigmoid",
                                   name="classification_layer")

    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        """
        Applies a U-Net model to the input.

        :param inputs: an input image
        :return: the output of a U-Net model
        """
        # Contracting: keep each pre-pooling output for the skip connections
        down1, x = self.conv_down1(inputs)
        down2, x = self.conv_down2(x)
        down3, x = self.conv_down3(x)
        down4, x = self.conv_down4(x)

        # Bottleneck
        x = self.conv_bottle(x)

        # Expanding: skip connections are consumed deepest-first
        x = self.conv_up1(x, down4)
        x = self.conv_up2(x, down3)
        x = self.conv_up3(x, down2)
        x = self.conv_up4(x, down1)

        return self.output_layer(x)

    def model(self,
              shape: Tuple[int, int, int] = (512, 512, 1)) -> tf.keras.Model:
        """
        Returns a U-Net model as tf.keras.Model. This is a workaround to use
        the functional api, which allows the model to be viewed.

        :param shape: the shape of the input
        :return: the tf.keras.Model instantiated using the functional api
        """
        x = Input(shape=shape)
        return tf.keras.Model(inputs=[x], outputs=self.call(x))

    def print_all_layers(
            self, shape: Tuple[int, int, int] = (512, 512, 1)) -> None:
        """
        Prints all the layers in the model, including the layers in
        the subclasses which make up the model. This uses the model()
        workaround function.

        :param shape: the input shape forwarded to model()
        :return: None
        """
        model_layers = self.model(shape).layers
        # The first and last entries are printed directly; every entry in
        # between is expanded into its own sub-layers.
        print(model_layers[0])
        for block in model_layers[1:-1]:
            for sub_layer in block.layers:
                print(sub_layer)
        print(model_layers[-1])

    def get_all_layers(
            self, shape: Tuple[int, int, int] = (512, 512, 1)
    ) -> List[tf.keras.layers.Layer]:
        """
        Returns the layers in the model, including the layers in
        the subclasses which make up the model. This uses the model()
        workaround function.

        The first entry of the functional model's layer list is not
        included in the result.

        :param shape: the input shape forwarded to model()
        :return: a list of layers
        """
        model_layers = self.model(shape).layers
        layers = [sub_layer
                  for block in model_layers[1:-1]
                  for sub_layer in block.layers]
        layers.append(model_layers[-1])
        return layers
# *===========================================================================*