U-Net
-
class src.pipelines.tensorflow_v2.models.unet.UnetBlock(num_filters, kernel_size, decode=False, encode=False, batch_norm=False, padding='same', activation='relu', initializer='he_normal', name='')
downblock(input_tensor)
A block to be used in the contracting path of a U-Net.
- Parameters
input_tensor (Tensor) – the input to the downblock
- Return type
Tensor
- Returns
the output before pooling, and the pooled output
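The source of downblock is not reproduced on this page, so the following is only a minimal sketch of a contracting-path block that returns both the pre-pooling output and the pooled output. The function name downblock_sketch, the two-convolution layout, the 2x2 max pooling, and the default num_filters=16 are assumptions made for illustration, not the actual UnetBlock implementation.

```python
import tensorflow as tf


def downblock_sketch(input_tensor, num_filters=16, kernel_size=3):
    """Hypothetical contracting-path block: two convolutions, then 2x2 max pooling."""
    conv = tf.keras.layers.Conv2D(
        num_filters, kernel_size, padding="same",
        activation="relu", kernel_initializer="he_normal",
    )(input_tensor)
    conv = tf.keras.layers.Conv2D(
        num_filters, kernel_size, padding="same",
        activation="relu", kernel_initializer="he_normal",
    )(conv)
    # Assumed pooling: halves the spatial dimensions.
    pooled = tf.keras.layers.MaxPool2D(pool_size=2)(conv)
    # `conv` is kept for the skip connection; `pooled` feeds the next block.
    return conv, pooled


x = tf.zeros((1, 64, 64, 1))
skip, pooled = downblock_sketch(x)
```

With padding='same' the convolutions keep the spatial size, so only the pooling step changes the height and width.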
-
upblock(input_tensor, downblock_output)
A block to be used in the expansive path of a U-Net. It requires as input the output from the previous upblock and the output prior to pooling from the corresponding downblock.
- Parameters
input_tensor (Tensor) – the output from the previous upblock
downblock_output (Tensor) – the output prior to pooling from the corresponding downblock
- Return type
Tensor
- Returns
the output after upsampling and convolutions
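As a rough illustration of how the two inputs of an upblock can be combined, the sketch below upsamples the incoming tensor, concatenates it with the pre-pooling downblock output along the channel axis, and applies two convolutions. The name upblock_sketch, the use of UpSampling2D (rather than, say, a transposed convolution), and the filter count are assumptions; the actual implementation may differ.

```python
import tensorflow as tf


def upblock_sketch(input_tensor, downblock_output, num_filters=16, kernel_size=3):
    """Hypothetical expansive-path block: upsample, concatenate skip, convolve."""
    # Assumed upsampling: doubles the spatial dimensions to match the skip tensor.
    up = tf.keras.layers.UpSampling2D(size=2)(input_tensor)
    # Skip connection: join upsampled features with the downblock's pre-pooling output.
    merged = tf.keras.layers.Concatenate(axis=-1)([up, downblock_output])
    conv = tf.keras.layers.Conv2D(
        num_filters, kernel_size, padding="same",
        activation="relu", kernel_initializer="he_normal",
    )(merged)
    conv = tf.keras.layers.Conv2D(
        num_filters, kernel_size, padding="same",
        activation="relu", kernel_initializer="he_normal",
    )(conv)
    return conv


deep = tf.zeros((1, 32, 32, 32))   # output of the previous (deeper) block
skip = tf.zeros((1, 64, 64, 16))   # pre-pooling output of the matching downblock
out = upblock_sketch(deep, skip)
```

The spatial sizes of the upsampled tensor and the skip tensor must match for the concatenation to succeed.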
-
bottleneck(input_tensor)
The bottleneck block used in a U-Net.
- Parameters
input_tensor (Tensor) – the output of the final downblock in the contracting path
- Return type
Tensor
- Returns
the output after convolutions
-
call(input_tensor, downblock_output=None)
Calls a downblock, a bottleneck block, or an upblock, depending on the values of the decode and encode flags.
- Parameters
input_tensor (Tensor) – the input to be given to the relevant block
downblock_output (Optional[Tensor]) – the output prior to pooling from the corresponding downblock; only required if an upblock is being used
- Return type
Tensor
- Returns
the tensor output of the relevant block
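The exact mapping from the encode and decode flags to a block is not spelled out here. One plausible reading, sketched below in plain Python, is that encode=True selects the downblock, decode=True selects the upblock, and neither flag selects the bottleneck. The helper select_block, the mapping itself, and the error handling are all assumptions for illustration.

```python
from typing import Optional


def select_block(encode: bool, decode: bool,
                 downblock_output: Optional[object] = None) -> str:
    """Return the name of the block a call would run (assumed mapping)."""
    if encode and decode:
        # Assumption: the two flags are mutually exclusive.
        raise ValueError("encode and decode cannot both be True")
    if encode:
        return "downblock"
    if decode:
        # The upblock is the only block that needs the skip tensor.
        if downblock_output is None:
            raise ValueError("an upblock needs downblock_output")
        return "upblock"
    return "bottleneck"
```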
-
-
class src.pipelines.tensorflow_v2.models.unet.Unet(output_channels, activation='relu', initializer='he_normal', filters=3)
A U-Net model.
-
call(inputs)
Applies a U-Net model to the input.
- Parameters
inputs (Tensor) – an input image
- Return type
Tensor
- Returns
the output of a U-Net model
-
model(shape=(512, 512, 1))
Returns a U-Net model as a tf.keras.Model. This is a workaround for using the functional API, which allows the model structure to be viewed.
- Parameters
shape (Tuple[int, int, int]) – the shape of the input
- Return type
Model
- Returns
the tf.keras.Model instantiated using the functional API
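The workaround described for model() is commonly written by passing a symbolic tf.keras.Input through the subclassed model's call method and wrapping the result in a functional tf.keras.Model. The sketch below shows that pattern on a deliberately tiny stand-in class; TinyModel and its single 1x1 convolution are hypothetical and are not the Unet class documented here.

```python
import tensorflow as tf


class TinyModel(tf.keras.Model):
    """Hypothetical stand-in for a subclassed model such as Unet."""

    def __init__(self, output_channels):
        super().__init__()
        self.conv = tf.keras.layers.Conv2D(output_channels, 1, padding="same")

    def call(self, inputs):
        return self.conv(inputs)

    def model(self, shape=(512, 512, 1)):
        # Trace call() on a symbolic input so the result is a functional
        # model whose layers and output shapes are known up front.
        x = tf.keras.Input(shape=shape)
        return tf.keras.Model(inputs=[x], outputs=self.call(x))


functional = TinyModel(output_channels=2).model(shape=(64, 64, 1))
```

Because the returned object is a functional model, calling functional.summary() lists each layer with its output shape, which a subclassed model cannot do before it has been built.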
-