# W-Net structure copied from the paper:
# The Little W-Net That Could: State-of-the-Art Retinal Vessel Segmentation
# with Minimalistic Models by Galdran et al.
# Link: https://arxiv.org/abs/2009.01907
from typing import Callable, List, Tuple, Union

import tensorflow as tf
from tensorflow.keras.layers import (Conv2D, MaxPool2D, Conv2DTranspose,
                                     Input, BatchNormalization, Concatenate)

from src.pipelines.tensorflow_v2.helpers.train_test import _TfPnsMixin
from src.pipelines.tensorflow_v2.models.unet_resnet import ResBlock
# *============================ Conv Bridge Block ============================*
class ConvBridgeBlock(tf.keras.Model):
    """
    A Convolutional Bridge Block to be used in a W-Net.

    Applies, in this order: a 3x3 convolution (stride 1, 'same' padding),
    the activation function, then batch normalisation.
    """

    def __init__(self,
                 channels: int,
                 activation: Union[str, Callable],
                 initializer: Union[str, tf.keras.initializers.Initializer]):
        """
        Instantiates a ConvBridgeBlock object
        :param channels: number of filters required for the conv layer
        :param activation: the activation function to use; either a callable
            (e.g. tf.nn.relu) or a Keras activation name (e.g. "relu")
        :param initializer: the kernel initializer for the conv layer, as a
            Keras initializer name or object
        """
        super().__init__()
        self.conv = Conv2D(filters=channels, kernel_size=3, strides=1,
                           padding='same',
                           kernel_initializer=initializer)
        self.bn = BatchNormalization()
        # call() invokes the activation directly, so a name such as "relu"
        # must be resolved to a function here; callables are returned as-is.
        self.activation = tf.keras.activations.get(activation)

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """
        Applies the ConvBridgeBlock to an input
        :param x: the input to apply the ConvBridgeBlock to
        :return: the output of the ConvBridgeBlock (conv -> activation -> BN)
        """
        x1 = self.conv(x)
        x1 = self.activation(x1)
        x1 = self.bn(x1)
        return x1
# *=============================== Mini U-Net ================================*
class MiniUnet(tf.keras.Model):
    """
    A mini U-Net; two of these are chained to make a W-Net model.

    With C = 2 ** filters, the layout is:
      * contracting path: ResBlock(8C) -> max-pool -> ResBlock(16C)
        -> max-pool
      * bottleneck: ResBlock(32C)
      * expanding path: transposed conv (16C, stride 2) -> ResBlock(16C),
        concatenated with the bridged 16C skip features; then transposed
        conv (8C, stride 2) -> ResBlock(8C), concatenated with the bridged
        8C skip features
      * output: 1x1 convolution with a sigmoid activation
    """

    def __init__(self,
                 output_channels: int,
                 activation: Union[str, Callable] = "relu",
                 initializer: str = "he_normal",
                 filters: int = 0):
        """
        Instantiates a MiniUnet object.
        :param output_channels: the number of output channels of the final
            1x1 convolution; this is related to the number of classes in the
            final prediction
        :param activation: the activation to use in all blocks apart from
            the last (a callable or an activation name)
        :param initializer: "he_normal" selects a He-normal initializer;
            any other value selects Glorot-uniform. Not applied to the
            final output layer.
        :param filters: exponent of the channel multiplier; every block's
            channel count is its base width (8, 16 or 32) times 2 ** filters
        """
        super().__init__()
        # NOTE(review): one seeded initializer instance is shared by every
        # layer below. On recent Keras versions a seeded initializer returns
        # identical values for identical shapes; confirm this is intended.
        if initializer == "he_normal":
            initializer = tf.keras.initializers.he_normal(seed=3141)
        else:
            initializer = tf.keras.initializers.glorot_uniform(seed=3141)
        # Components
        # Contracting
        self.res_down1 = ResBlock(8 * 2 ** filters, decode=True,
                                  activation=activation,
                                  initializer=initializer)
        self.pool1 = MaxPool2D(pool_size=2, strides=2)
        self.res_down2 = ResBlock(16 * 2 ** filters, decode=True,
                                  activation=activation,
                                  initializer=initializer)
        self.pool2 = MaxPool2D(pool_size=2, strides=2)
        self.res_bottle = ResBlock(32 * 2 ** filters, decode=True,
                                   activation=activation,
                                   initializer=initializer)
        # Expanding
        self.trans_conv1 = Conv2DTranspose(filters=16 * 2 ** filters,
                                           kernel_size=2, strides=2,
                                           padding='same',
                                           activation=activation,
                                           kernel_initializer=initializer)
        self.res_up1 = ResBlock(16 * 2 ** filters, activation=activation,
                                initializer=initializer)
        self.bridge1 = ConvBridgeBlock(16 * 2 ** filters,
                                       activation=activation,
                                       initializer=initializer)
        self.concat1 = Concatenate()
        self.trans_conv2 = Conv2DTranspose(filters=8 * 2 ** filters,
                                           kernel_size=2,
                                           activation=activation,
                                           strides=2, padding='same',
                                           kernel_initializer=initializer)
        self.res_up2 = ResBlock(8 * 2 ** filters, activation=activation,
                                initializer=initializer)
        self.bridge2 = ConvBridgeBlock(8 * 2 ** filters,
                                       activation=activation,
                                       initializer=initializer)
        self.concat2 = Concatenate()
        # Output: default initializer, sigmoid activation
        self.output_layer = Conv2D(output_channels, 1, strides=1,
                                   padding='same',
                                   activation="sigmoid",
                                   name="classification_layer")

    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        """
        Applies a mini U-Net to the input.
        :param inputs: an input image
        :return: the sigmoid output of the final 1x1 convolution
        """
        # Contracting; keep the pre-pooling features as skip connections
        down1 = self.res_down1(inputs)
        x = self.pool1(down1)
        down2 = self.res_down2(x)
        x = self.pool2(down2)
        # Bottleneck
        x = self.res_bottle(x)
        # Expanding
        x = self.trans_conv1(x)
        x = self.res_up1(x)
        # Each skip connection passes through a ConvBridgeBlock before it is
        # concatenated with the upsampled features (presumably following the
        # reference W-Net design).
        down2 = self.bridge1(down2)
        x = self.concat1([x, down2])
        x = self.trans_conv2(x)
        x = self.res_up2(x)
        down1 = self.bridge2(down1)
        x = self.concat2([x, down1])
        x = self.output_layer(x)
        return x

    def model(self,
              shape: Tuple[int, int, int] = (512, 512, 1)) -> tf.keras.Model:
        """
        Returns the mini U-Net as a tf.keras.Model. This is a workaround to
        use the functional API, which allows the model to be viewed.
        :param shape: the shape of the input, excluding the batch dimension
        :return: the tf.keras.Model instantiated using the functional API
        """
        x = Input(shape=shape)
        return tf.keras.Model(inputs=[x], outputs=self.call(x))

    def print_all_layers(self) -> None:
        """
        Prints all the layers in the model, including the layers in
        the sub-models which make up the model (one level deep). This uses
        the model() workaround function.
        :return: None
        """
        for layer in self.get_all_layers():
            print(layer)

    def get_all_layers(self) -> List[tf.keras.layers.Layer]:
        """
        Returns all the layers in the model, including the layers in
        the sub-models which make up the model (one level deep). This uses
        the model() workaround function.
        :return: a list of layers; a nested model contributes its own
            layers, and any other layer is included as itself
        """
        layers = []
        for layer in self.model().layers:
            # Nested models (e.g. ConvBridgeBlock) expose a `layers` list;
            # plain layers such as Input or Concatenate do not.
            sub_layers = getattr(layer, "layers", None)
            if sub_layers is None:
                layers.append(layer)
            else:
                layers.extend(sub_layers)
        return layers
# *================================== W-Net ==================================*
class WNet(tf.keras.Model, _TfPnsMixin):
    """
    A W-Net model. This model chains two Mini U-Nets: the prediction of the
    first Mini U-Net is concatenated with the input image (Concatenate's
    default last axis), and the result is fed to the second Mini U-Net.
    """

    def __init__(self,
                 output_channels: int,
                 activation: Union[str, Callable] = "relu",
                 initializer: str = "he_normal",
                 filters: int = 0):
        """
        Instantiates a WNet object.
        :param output_channels: the number of output channels of each mini
            U-Net's final layer; this is related to the number of classes in
            the final prediction
        :param activation: the activation to use in all blocks apart from
            the last. "relu" maps to tf.nn.relu and "selu" to tf.nn.selu;
            any other activation name or callable is resolved with
            tf.keras.activations.get (unknown names raise an error)
        :param initializer: the initialiser to use in all blocks apart from
            the last ("he_normal", otherwise Glorot-uniform)
        :param filters: exponent of the channel multiplier passed to both
            mini U-Nets
        """
        super().__init__()
        if activation == "relu":
            activation = tf.nn.relu
        elif activation == "selu":
            activation = tf.nn.selu
        else:
            # Previously every non-"relu" value silently became selu, even a
            # callable; resolve it properly instead.
            activation = tf.keras.activations.get(activation)
        self.activation = activation
        # Both mini U-Nets predict `output_channels` channels; the second
        # receives the input image concatenated with the first's prediction.
        self.unet1 = MiniUnet(output_channels, activation, initializer,
                              filters)
        self.unet2 = MiniUnet(output_channels, activation, initializer,
                              filters)
        self.concat = Concatenate()

    def call(self,
             input_tensor: tf.Tensor,
             training: bool = True
             ) -> Union[tf.Tensor, Tuple[tf.Tensor, tf.Tensor]]:
        """
        Applies a W-Net model to an input image, i.e. two sequential mini
        U-Nets. If the W-Net is not training then only the output of the
        second mini U-Net is returned.
        :param input_tensor: an input image
        :param training: whether the W-Net is being applied to a training
            sample (defaults to True)
        :return: (first output, second output) when training, otherwise
            only the output of the second mini U-Net
        """
        x1 = self.unet1(input_tensor)
        x = self.concat([input_tensor, x1])
        x2 = self.unet2(x)
        if not training:
            return x2
        return x1, x2

    def model(self,
              shape: Tuple[int, int, int] = (512, 512, 1)) -> tf.keras.Model:
        """
        Returns the W-Net as a tf.keras.Model. This is a workaround to use
        the functional API, which allows the model to be viewed. Because
        call() is invoked with its default training=True, the returned model
        has both mini U-Net outputs.
        :param shape: the shape of the input, excluding the batch dimension
        :return: the tf.keras.Model instantiated using the functional API
        """
        x = Input(shape=shape)
        return tf.keras.Model(inputs=[x], outputs=self.call(x))

    def print_all_layers(self) -> None:
        """
        Prints all the layers in the model, including the layers in
        the sub-models which make up the model (one level deep). This uses
        the model() workaround function.
        :return: None
        """
        for layer in self.get_all_layers():
            print(layer)

    def get_all_layers(self) -> List[tf.keras.layers.Layer]:
        """
        Returns all the layers in the model, including the layers in
        the sub-models which make up the model (one level deep). This uses
        the model() workaround function.
        :return: a list of layers; a nested model contributes its own
            layers, and any other layer is included as itself
        """
        layers = []
        for layer in self.model().layers:
            # Nested models (e.g. the two MiniUnets) expose a `layers` list;
            # plain layers such as Input or Concatenate do not.
            sub_layers = getattr(layer, "layers", None)
            if sub_layers is None:
                layers.append(layer)
            else:
                layers.extend(sub_layers)
        return layers
# *===========================================================================*