123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297 |
- """Example applications for image classifcation.
- Each function returns a pretrained ImageNet model.
- The models are based on keras.applications models and
- contain additionally pretrained patterns.
- The returned dictionary contains the following
- keys\: model, in, sm_out, out, image_shape, color_coding,
- preprocess_f, patterns.
- Function parameters\:
- :param load_weights: Download or access cached weights.
- :param load_patterns: Download or access cached patterns.
- """
- # todo: rename in, sm_out, out to input_tensors, output_tensors,
- # todo: softmax_output_tenors
- # Get Python six functionality:
- from __future__ import\
- absolute_import, print_function, division, unicode_literals
- from builtins import range
- ###############################################################################
- ###############################################################################
- ###############################################################################
- import keras.backend as K
- import keras.applications.resnet50
- import keras.applications.vgg16
- import keras.applications.vgg19
- import keras.applications.inception_v3
- import keras.applications.inception_resnet_v2
- import keras.applications.densenet
- import keras.applications.nasnet
- import keras.utils.data_utils
- import numpy as np
- import warnings
- from ..utils.keras import graph as kgraph
- __all__ = [
- "vgg16",
- "vgg19",
- "resnet50",
- "inception_v3",
- "inception_resnet_v2",
- "densenet121",
- "densenet169",
- "densenet201",
- "nasnet_large",
- "nasnet_mobile",
- ]
- ###############################################################################
- ###############################################################################
- ###############################################################################
- PATTERNS = {
- "vgg16_pattern_type_relu_tf_dim_ordering_tf_kernels.npz": {
- "url": "https://www.dropbox.com/s/15lip81fzvbgkaa/vgg16_pattern_type_relu_tf_dim_ordering_tf_kernels.npz?dl=1",
- "hash": "8c2abe648e116a93fd5027fab49177b0",
- },
- "vgg19_pattern_type_relu_tf_dim_ordering_tf_kernels.npz": {
- "url": "https://www.dropbox.com/s/nc5empj78rfe9hm/vgg19_pattern_type_relu_tf_dim_ordering_tf_kernels.npz?dl=1",
- "hash": "3258b6c64537156afe75ca7b3be44742",
- },
- }
- def _get_patterns_info(netname, pattern_type):
- if pattern_type is True:
- pattern_type = "relu"
- file_name = ("%s_pattern_type_%s_tf_dim_ordering_tf_kernels.npz" %
- (netname, pattern_type))
- return {"file_name": file_name,
- "url": PATTERNS[file_name]["url"],
- "hash": PATTERNS[file_name]["hash"]}
- ###############################################################################
- ###############################################################################
- ###############################################################################
- def _prepare_keras_net(netname,
- clazz,
- image_shape,
- preprocess_f,
- preprocess_mode=None,
- color_coding="RGB",
- load_weights=False,
- load_patterns=False):
- net = {}
- net["name"] = netname
- net["image_shape"] = image_shape
- if K.image_data_format() == "channels_first":
- net["input_shape"] = [None, 3]+image_shape
- else:
- net["input_shape"] = [None]+image_shape+[3]
- weights = None
- if load_weights is True:
- weights = "imagenet"
- model = clazz(weights=weights,
- input_shape=tuple(net["input_shape"][1:]))
- net["model"] = model
- net["in"] = model.inputs
- net["sm_out"] = model.outputs
- net["out"] = kgraph.pre_softmax_tensors(model.outputs)
- net["color_coding"] = color_coding
- net["preprocess_f"] = preprocess_f
- net["input_range"] = {
- None: (-128, 128),
- "caffe": (-128, 128),
- "tf": (-1, 1),
- "torch": (-3, 3),
- }[preprocess_mode]
- net["patterns"] = None
- if load_patterns is not False:
- try:
- pattern_info = _get_patterns_info(netname, load_patterns)
- except KeyError:
- warnings.warn("There are no patterns for network '%s'." % netname)
- else:
- patterns_path = keras.utils.data_utils.get_file(
- pattern_info["file_name"],
- pattern_info["url"],
- cache_subdir="innvestigate_patterns",
- hash_algorithm="md5",
- file_hash=pattern_info["hash"])
- patterns_file = np.load(patterns_path)
- patterns = [patterns_file["arr_%i" % i]
- for i in range(len(patterns_file.keys()))]
- net["patterns"] = patterns
- return net
- ###############################################################################
- ###############################################################################
- ###############################################################################
- def vgg16(load_weights=False, load_patterns=False):
- return _prepare_keras_net(
- "vgg16",
- keras.applications.vgg16.VGG16,
- [224, 224],
- preprocess_f=keras.applications.vgg16.preprocess_input,
- preprocess_mode="caffe",
- color_coding="BGR",
- load_weights=load_weights,
- load_patterns=load_patterns)
- def vgg19(load_weights=False, load_patterns=False):
- return _prepare_keras_net(
- "vgg19",
- keras.applications.vgg19.VGG19,
- [224, 224],
- preprocess_f=keras.applications.vgg19.preprocess_input,
- preprocess_mode="caffe",
- color_coding="BGR",
- load_weights=load_weights,
- load_patterns=load_patterns)
- ###############################################################################
- ###############################################################################
- ###############################################################################
- def resnet50(load_weights=False, load_patterns=False):
- return _prepare_keras_net(
- "resnet50",
- keras.applications.resnet50.ResNet50,
- [224, 224],
- preprocess_f=keras.applications.resnet50.preprocess_input,
- preprocess_mode="caffe",
- color_coding="BGR",
- load_weights=load_weights,
- load_patterns=load_patterns)
- ###############################################################################
- ###############################################################################
- ###############################################################################
- def inception_v3(load_weights=False, load_patterns=False):
- return _prepare_keras_net(
- "inception_v3",
- keras.applications.inception_v3.InceptionV3,
- [299, 299],
- preprocess_f=keras.applications.inception_v3.preprocess_input,
- preprocess_mode="tf",
- load_weights=load_weights,
- load_patterns=load_patterns)
- ###############################################################################
- ###############################################################################
- ###############################################################################
- def inception_resnet_v2(load_weights=False, load_patterns=False):
- return _prepare_keras_net(
- "inception_resnet_v2",
- keras.applications.inception_resnet_v2.InceptionResNetV2,
- [299, 299],
- preprocess_f=keras.applications.inception_resnet_v2.preprocess_input,
- preprocess_mode="tf",
- load_weights=load_weights,
- load_patterns=load_patterns)
- ###############################################################################
- ###############################################################################
- ###############################################################################
- def densenet121(load_weights=False, load_patterns=False):
- return _prepare_keras_net(
- "densenet121",
- keras.applications.densenet.DenseNet121,
- [224, 224],
- preprocess_f=keras.applications.densenet.preprocess_input,
- preprocess_mode="torch",
- load_weights=load_weights,
- load_patterns=load_patterns)
- def densenet169(load_weights=False, load_patterns=False):
- return _prepare_keras_net(
- "densenet169",
- keras.applications.densenet.DenseNet169,
- [224, 224],
- preprocess_f=keras.applications.densenet.preprocess_input,
- preprocess_mode="torch",
- load_weights=load_weights,
- load_patterns=load_patterns)
- def densenet201(load_weights=False, load_patterns=False):
- return _prepare_keras_net(
- "densenet201",
- keras.applications.densenet.DenseNet201,
- [224, 224],
- preprocess_f=keras.applications.densenet.preprocess_input,
- preprocess_mode="torch",
- load_weights=load_weights,
- load_patterns=load_patterns)
- ###############################################################################
- ###############################################################################
- ###############################################################################
- def nasnet_large(load_weights=False, load_patterns=False):
- if K.image_data_format() == "channels_first":
- raise Exception("NASNet is not available for channels first.")
- return _prepare_keras_net(
- "nasnet_large",
- keras.applications.nasnet.NASNetLarge,
- [331, 331],
- color_coding="BGR",
- preprocess_f=keras.applications.nasnet.preprocess_input,
- preprocess_mode="tf",
- load_weights=load_weights,
- load_patterns=load_patterns)
- def nasnet_mobile(load_weights=False, load_patterns=False):
- if K.image_data_format() == "channels_first":
- raise Exception("NASNet is not available for channels first.")
- return _prepare_keras_net(
- "nasnet_mobile",
- keras.applications.nasnet.NASNetMobile,
- [224, 224],
- color_coding="BGR",
- preprocess_f=keras.applications.nasnet.preprocess_input,
- preprocess_mode="tf",
- load_weights=load_weights,
- load_patterns=load_patterns)
|