r"""Example applications for image classification.

Each function returns a pretrained ImageNet model.
The models are based on keras.applications models and
contain additionally pretrained patterns.

The returned dictionary contains the following keys\:
model, in, sm_out, out, image_shape, color_coding,
preprocess_f, patterns.
In addition the keys name, input_shape and input_range are set.

Function parameters\:

:param load_weights: Download or access cached weights.
:param load_patterns: Download or access cached patterns.
"""
# NOTE: the docstring above is a raw string because it contains the
# sequence "\:", which is an invalid escape in a normal string literal.

# todo: rename in, sm_out, out to input_tensors, output_tensors,
# todo: softmax_output_tensors

# Get Python six functionality:
from __future__ import\
    absolute_import, print_function, division, unicode_literals
from builtins import range


###############################################################################
###############################################################################
###############################################################################


import keras.backend as K
import keras.applications.resnet50
import keras.applications.vgg16
import keras.applications.vgg19
import keras.applications.inception_v3
import keras.applications.inception_resnet_v2
import keras.applications.densenet
import keras.applications.nasnet
import keras.utils.data_utils
import numpy as np
import warnings

from ..utils.keras import graph as kgraph


__all__ = [
    "vgg16",
    "vgg19",

    "resnet50",

    "inception_v3",
    "inception_resnet_v2",

    "densenet121",
    "densenet169",
    "densenet201",

    "nasnet_large",
    "nasnet_mobile",
]


###############################################################################
###############################################################################
###############################################################################


# Registry of downloadable pattern files, keyed by file name. Each entry
# holds the download URL and the md5 hash used to verify the download.
PATTERNS = {
    "vgg16_pattern_type_relu_tf_dim_ordering_tf_kernels.npz": {
        "url": "https://www.dropbox.com/s/15lip81fzvbgkaa/vgg16_pattern_type_relu_tf_dim_ordering_tf_kernels.npz?dl=1",
        "hash": "8c2abe648e116a93fd5027fab49177b0",
    },
    "vgg19_pattern_type_relu_tf_dim_ordering_tf_kernels.npz": {
        "url": "https://www.dropbox.com/s/nc5empj78rfe9hm/vgg19_pattern_type_relu_tf_dim_ordering_tf_kernels.npz?dl=1",
        "hash": "3258b6c64537156afe75ca7b3be44742",
    },
}


def _get_patterns_info(netname, pattern_type):
    """Look up the download information of a pattern file.

    :param netname: Name of the network, e.g. "vgg16".
    :param pattern_type: Name of the pattern type. The literal ``True``
      is treated as "relu".
    :return: Dictionary with the keys file_name, url and hash.
    :raises KeyError: If PATTERNS has no entry for the resulting file name.
    """
    if pattern_type is True:
        pattern_type = "relu"

    file_name = ("%s_pattern_type_%s_tf_dim_ordering_tf_kernels.npz" %
                 (netname, pattern_type))

    # A missing entry raises KeyError here; _prepare_keras_net relies on it.
    entry = PATTERNS[file_name]
    return {"file_name": file_name,
            "url": entry["url"],
            "hash": entry["hash"]}


###############################################################################
###############################################################################
###############################################################################


def _prepare_keras_net(netname,
                       clazz,
                       image_shape,
                       preprocess_f,
                       preprocess_mode=None,
                       color_coding="RGB",
                       load_weights=False,
                       load_patterns=False):
    """Build a model and collect it with its meta data in a dictionary.

    :param netname: Name of the network; stored under "name" and used to
      look up the pattern file.
    :param clazz: Callable that creates the model. It is called with the
      keyword arguments weights and input_shape and must return an object
      with the attributes inputs and outputs.
    :param image_shape: List with the two spatial dimensions of the input.
      It must be a list because it is concatenated with other lists.
    :param preprocess_f: Preprocessing function; stored unchanged.
    :param preprocess_mode: One of None, "caffe", "tf" or "torch"; selects
      the tuple stored under "input_range".
    :param color_coding: Color coding string; stored unchanged.
    :param load_weights: Only the literal ``True`` requests the "imagenet"
      weights; any other value creates the model with weights=None.
    :param load_patterns: ``False`` disables pattern loading. ``True``
      selects the "relu" patterns. Any other value is used as the pattern
      type name.
    :return: Dictionary with the keys name, image_shape, input_shape,
      model, in, sm_out, out, color_coding, preprocess_f, input_range
      and patterns.
    :raises KeyError: If preprocess_mode is not one of the known modes.
    """
    net = {}
    net["name"] = netname
    net["image_shape"] = image_shape

    # The leading None is the placeholder for the batch axis; the channel
    # axis position follows the backend's image data format.
    if K.image_data_format() == "channels_first":
        net["input_shape"] = [None, 3]+image_shape
    else:
        net["input_shape"] = [None]+image_shape+[3]

    weights = None
    if load_weights is True:
        weights = "imagenet"

    # The model itself is created without the batch axis.
    model = clazz(weights=weights,
                  input_shape=tuple(net["input_shape"][1:]))
    net["model"] = model
    net["in"] = model.inputs
    net["sm_out"] = model.outputs
    # NOTE(review): presumably the tensors feeding the final softmax;
    # confirm against kgraph.pre_softmax_tensors.
    net["out"] = kgraph.pre_softmax_tensors(model.outputs)
    net["color_coding"] = color_coding
    net["preprocess_f"] = preprocess_f
    # NOTE(review): presumably the (low, high) value range of the
    # preprocessed input for each mode; confirm against the callers.
    net["input_range"] = {
        None: (-128, 128),
        "caffe": (-128, 128),
        "tf": (-1, 1),
        "torch": (-3, 3),
    }[preprocess_mode]

    net["patterns"] = None
    if load_patterns is not False:
        try:
            pattern_info = _get_patterns_info(netname, load_patterns)
        except KeyError:
            # Missing patterns are not fatal; the model is still returned.
            warnings.warn("There are no patterns for model '%s'." % netname)
        else:
            patterns_path = keras.utils.data_utils.get_file(
                pattern_info["file_name"],
                pattern_info["url"],
                cache_subdir="innvestigate_patterns",
                hash_algorithm="md5",
                file_hash=pattern_info["hash"])
            # Close the archive once all arrays are read; the arrays are
            # materialized on item access and stay valid afterwards.
            with np.load(patterns_path) as patterns_file:
                n_patterns = len(patterns_file.files)
                # Read by explicit index to keep arr_0, arr_1, ... order.
                patterns = [patterns_file["arr_%i" % i]
                            for i in range(n_patterns)]
            net["patterns"] = patterns

    return net


###############################################################################
###############################################################################
###############################################################################


def vgg16(load_weights=False, load_patterns=False):
    """Return the VGG16 network dictionary.

    Uses a 224x224 input, "caffe" preprocessing mode and BGR color coding.

    :param load_weights: Download or access cached weights.
    :param load_patterns: Download or access cached patterns.
    """
    return _prepare_keras_net(
        "vgg16",
        keras.applications.vgg16.VGG16,
        [224, 224],
        preprocess_f=keras.applications.vgg16.preprocess_input,
        preprocess_mode="caffe",
        color_coding="BGR",
        load_weights=load_weights,
        load_patterns=load_patterns)


def vgg19(load_weights=False, load_patterns=False):
    """Return the VGG19 network dictionary.

    Uses a 224x224 input, "caffe" preprocessing mode and BGR color coding.

    :param load_weights: Download or access cached weights.
    :param load_patterns: Download or access cached patterns.
    """
    return _prepare_keras_net(
        "vgg19",
        keras.applications.vgg19.VGG19,
        [224, 224],
        preprocess_f=keras.applications.vgg19.preprocess_input,
        preprocess_mode="caffe",
        color_coding="BGR",
        load_weights=load_weights,
        load_patterns=load_patterns)


###############################################################################
###############################################################################
###############################################################################


def resnet50(load_weights=False, load_patterns=False):
    """Return the ResNet50 network dictionary.

    Uses a 224x224 input, "caffe" preprocessing mode and BGR color coding.

    :param load_weights: Download or access cached weights.
    :param load_patterns: Download or access cached patterns.
    """
    return _prepare_keras_net(
        "resnet50",
        keras.applications.resnet50.ResNet50,
        [224, 224],
        preprocess_f=keras.applications.resnet50.preprocess_input,
        preprocess_mode="caffe",
        color_coding="BGR",
        load_weights=load_weights,
        load_patterns=load_patterns)


###############################################################################
###############################################################################
###############################################################################


def inception_v3(load_weights=False, load_patterns=False):
    """Return the InceptionV3 network dictionary.

    Uses a 299x299 input, "tf" preprocessing mode and the default RGB
    color coding.

    :param load_weights: Download or access cached weights.
    :param load_patterns: Download or access cached patterns.
    """
    return _prepare_keras_net(
        "inception_v3",
        keras.applications.inception_v3.InceptionV3,
        [299, 299],
        preprocess_f=keras.applications.inception_v3.preprocess_input,
        preprocess_mode="tf",
        load_weights=load_weights,
        load_patterns=load_patterns)


###############################################################################
###############################################################################
###############################################################################


def inception_resnet_v2(load_weights=False, load_patterns=False):
    """Return the InceptionResNetV2 network dictionary.

    Uses a 299x299 input, "tf" preprocessing mode and the default RGB
    color coding.

    :param load_weights: Download or access cached weights.
    :param load_patterns: Download or access cached patterns.
    """
    return _prepare_keras_net(
        "inception_resnet_v2",
        keras.applications.inception_resnet_v2.InceptionResNetV2,
        [299, 299],
        preprocess_f=keras.applications.inception_resnet_v2.preprocess_input,
        preprocess_mode="tf",
        load_weights=load_weights,
        load_patterns=load_patterns)


###############################################################################
###############################################################################
###############################################################################


def densenet121(load_weights=False, load_patterns=False):
    """Return the DenseNet121 network dictionary.

    Uses a 224x224 input, "torch" preprocessing mode and the default RGB
    color coding.

    :param load_weights: Download or access cached weights.
    :param load_patterns: Download or access cached patterns.
    """
    return _prepare_keras_net(
        "densenet121",
        keras.applications.densenet.DenseNet121,
        [224, 224],
        preprocess_f=keras.applications.densenet.preprocess_input,
        preprocess_mode="torch",
        load_weights=load_weights,
        load_patterns=load_patterns)


def densenet169(load_weights=False, load_patterns=False):
    """Return the DenseNet169 network dictionary.

    Uses a 224x224 input, "torch" preprocessing mode and the default RGB
    color coding.

    :param load_weights: Download or access cached weights.
    :param load_patterns: Download or access cached patterns.
    """
    return _prepare_keras_net(
        "densenet169",
        keras.applications.densenet.DenseNet169,
        [224, 224],
        preprocess_f=keras.applications.densenet.preprocess_input,
        preprocess_mode="torch",
        load_weights=load_weights,
        load_patterns=load_patterns)


def densenet201(load_weights=False, load_patterns=False):
    """Return the DenseNet201 network dictionary.

    Uses a 224x224 input, "torch" preprocessing mode and the default RGB
    color coding.

    :param load_weights: Download or access cached weights.
    :param load_patterns: Download or access cached patterns.
    """
    return _prepare_keras_net(
        "densenet201",
        keras.applications.densenet.DenseNet201,
        [224, 224],
        preprocess_f=keras.applications.densenet.preprocess_input,
        preprocess_mode="torch",
        load_weights=load_weights,
        load_patterns=load_patterns)
###############################################################################
###############################################################################
###############################################################################


def nasnet_large(load_weights=False, load_patterns=False):
    """Return the NASNetLarge network dictionary.

    Uses a 331x331 input, "tf" preprocessing mode and BGR color coding.

    :param load_weights: Download or access cached weights.
    :param load_patterns: Download or access cached patterns.
    :raises Exception: If the backend image data format is
      "channels_first".
    """
    data_format = K.image_data_format()
    if data_format == "channels_first":
        raise Exception("NASNet is not available for channels first.")

    # NOTE(review): the other "tf"-mode models in this module keep the
    # default RGB color coding; confirm that BGR is intended here.
    options = dict(
        preprocess_f=keras.applications.nasnet.preprocess_input,
        preprocess_mode="tf",
        color_coding="BGR",
        load_weights=load_weights,
        load_patterns=load_patterns,
    )
    model_factory = keras.applications.nasnet.NASNetLarge
    return _prepare_keras_net("nasnet_large", model_factory, [331, 331],
                              **options)


def nasnet_mobile(load_weights=False, load_patterns=False):
    """Return the NASNetMobile network dictionary.

    Uses a 224x224 input, "tf" preprocessing mode and BGR color coding.

    :param load_weights: Download or access cached weights.
    :param load_patterns: Download or access cached patterns.
    :raises Exception: If the backend image data format is
      "channels_first".
    """
    data_format = K.image_data_format()
    if data_format == "channels_first":
        raise Exception("NASNet is not available for channels first.")

    # NOTE(review): the other "tf"-mode models in this module keep the
    # default RGB color coding; confirm that BGR is intended here.
    options = dict(
        preprocess_f=keras.applications.nasnet.preprocess_input,
        preprocess_mode="tf",
        color_coding="BGR",
        load_weights=load_weights,
        load_patterns=load_patterns,
    )
    model_factory = keras.applications.nasnet.NASNetMobile
    return _prepare_keras_net("nasnet_mobile", model_factory, [224, 224],
                              **options)