Source code for src.actions.prediction

import json
import logging
from typing import Dict, List, Optional, Tuple, Union

from tensorflow import keras

from src.data_model.data_model import LeafSequence, MaskSequence
from src.helpers.utilities import classification_report
from src.pipelines.tensorflow_v2.losses.custom_losses import *
from src.pipelines.tensorflow_v2.models.unet import Unet
from src.pipelines.tensorflow_v2.models.unet_resnet import UnetResnet
from src.pipelines.tensorflow_v2.models.wnet import WNet

LOGGER = logging.getLogger(__name__)


# *================================ get model ================================*
def get_workaround_details(compilation_dict: Dict):
    """
    Rebuilds the objects needed to restore a trained TF model from the
    compilation dict that was saved alongside it at training time.

    :param compilation_dict: the saved compilation details; the keys read here
        are "model", "loss" and "opt", plus "activation", "initializer" and
        "filters" for the model and "loss_weight" for the "wce" and "focal"
        losses
    :return: a tuple of (model, loss, optimiser, metrics); the optimiser is
        returned as the keras optimiser *class* (not an instance), and metrics
        is a list of binary accuracy, precision and recall
    :raises ValueError: if the "model", "loss" or "opt" entry is not one of
        the supported options

    .. note:: This workaround exists because normal model saving for
        subclassed TF models did not work at the time of writing.
    """

    def _resolve(choice, factories, complaint):
        # Walk the options in order and call only the factory that matches, so
        # nothing is looked up or constructed for the options not selected.
        for key, factory in factories:
            if choice == key:
                return factory()
        raise ValueError(complaint)

    def _net_args():
        # Every supported network takes the same constructor arguments.
        return (1,
                compilation_dict["activation"],
                compilation_dict["initializer"],
                compilation_dict["filters"])

    # model
    model = _resolve(
        compilation_dict["model"],
        (
            ("unet", lambda: Unet(*_net_args())),
            ("unet_resnet", lambda: UnetResnet(*_net_args())),
            ("wnet", lambda: WNet(*_net_args())),
        ),
        "Please provide a valid answer for model choice, "
        "options are unet, unet_resnet, or wnet",
    )

    # loss
    loss = _resolve(
        compilation_dict["loss"],
        (
            ("bce", lambda: keras.losses.binary_crossentropy),
            ("wce", lambda: WeightedCE(compilation_dict["loss_weight"])),
            ("focal", lambda: FocalLoss(compilation_dict["loss_weight"])),
            ("dice", lambda: SoftDiceLoss()),
        ),
        "Please provide a valid answer for loss choice, "
        "options are bce, wce, focal, or dice",
    )

    # optimiser (the class itself; the caller instantiates it)
    opt = _resolve(
        compilation_dict["opt"],
        (
            ("adam", lambda: keras.optimizers.Adam),
            ("sgd", lambda: keras.optimizers.SGD),
        ),
        "Please provide a valid answer for optimiser choice, "
        "options are adam or sgd",
    )

    # metrics
    accuracy = keras.metrics.BinaryAccuracy(name='accuracy')
    precision = keras.metrics.Precision(name='precision')
    recall = keras.metrics.Recall(name='recall')
    metrics = [accuracy, precision, recall]

    return model, loss, opt, metrics
# *=============================== prediction ================================*
def predict_tensorflow(lseq_list: List[LeafSequence],
                       model_weight_path: str,
                       leaf_shape: Tuple[int, int],
                       cr_csv_list: Union[str, List[str]] = "",
                       mseq_list: Optional[List[MaskSequence]] = None,
                       threshold: float = 0.5,
                       format_dict: Optional[Dict] = None) -> None:
    """
    Makes predictions for the images in each LeafSequence in the list of leaf
    sequences, by calling ``predict_leaf_sequence`` on each one. If
    classification report csv save paths are provided, a classification
    report is also generated and saved for each sequence.

    :param lseq_list: a list of LeafSequence to make predictions for
    :param model_weight_path: the path prefix of the saved tf model;
        "compilation.json" is appended to it directly (no separator is added)
        to locate the saved compilation details
    :param leaf_shape: the shape of leaf used when training the saved model
    :param cr_csv_list: where the classification reports should be saved;
        either a single ";"-separated string of paths or a list of paths. An
        empty value (or an empty first path) disables the reports. If
        provided, it needs one path per entry in lseq_list
    :param mseq_list: the list of mseqs to use when generating the
        classification reports; this must be provided if cr_csv_list is
        provided, and it needs one entry per entry in lseq_list
    :param threshold: the threshold passed on when saving predictions; i.e. a
        pixel is saved as an embolism if p(embolism) > threshold
    :param format_dict: extra keyword arguments forwarded to
        ``predict_leaf_sequence``; described as the format to use when loading
        the LeafSequence Leaf images
    :raises ValueError: if reports are requested but mseq_list is missing, or
        if cr_csv_list or mseq_list has fewer entries than lseq_list
    :return: None
    """
    if format_dict is None:
        format_dict = {}

    # The report paths historically arrived as one ";"-separated string, while
    # the signature advertised a list; accept both forms.
    if cr_csv_list is None:
        cr_paths = []
    elif isinstance(cr_csv_list, str):
        cr_paths = cr_csv_list.split(";")
    else:
        cr_paths = list(cr_csv_list)

    # Reporting is switched on by a non-empty first path.
    make_reports = bool(cr_paths) and bool(cr_paths[0])

    # Fail before the model is loaded and any prediction is run, rather than
    # part-way through the loop below.
    if make_reports:
        if mseq_list is None:
            raise ValueError("mseq_list must be provided when cr_csv_list "
                             "is provided")
        if len(cr_paths) < len(lseq_list) or len(mseq_list) < len(lseq_list):
            raise ValueError("cr_csv_list and mseq_list must each have one "
                             "entry per LeafSequence in lseq_list")

    with open(model_weight_path + "compilation.json", "r") as json_file:
        compilation_dict = json.load(json_file)

    model, loss, opt, metrics = get_workaround_details(compilation_dict)

    # opt is an optimiser class; instantiate it with the saved learning rate
    model.load_workaround(leaf_shape, leaf_shape, loss,
                          opt(float(compilation_dict["lr"])), metrics,
                          model_weight_path)

    # The reports read the prediction arrays back off the leaves below, so
    # memory saving is only requested when no report is needed.
    memory_saving = not make_reports

    for i, lseq in enumerate(lseq_list):
        lseq.predict_leaf_sequence(model, leaf_shape[0], leaf_shape[1],
                                   memory_saving=memory_saving,
                                   leaf_shape=leaf_shape,
                                   threshold=threshold,
                                   **format_dict)

        if not make_reports:
            continue

        mseq = mseq_list[i]
        mseq.load_extracted_images(load_image=True)

        temp_pred_list = []
        temp_mask_list = []

        for leaf, mask in zip(lseq.image_objects, mseq.image_objects):
            # NOTE(review): dividing by 255 assumes 8-bit arrays - confirm
            temp_pred_list.append(leaf.prediction_array / 255.0)
            temp_mask_list.append(mask.image_array / 255.0)

            # save memory
            del leaf.image_array
            del mask.image_array

        classification_report(temp_pred_list, temp_mask_list,
                              save_path=cr_paths[i])