imagenet.py 9.9 KB


"""Example applications for image classification.

Each function returns an ImageNet model, pretrained when ``load_weights``
is True. The models are based on keras.applications models and can
additionally provide pretrained patterns where available.
The returned dictionary contains the following keys: name, model, in,
sm_out, out, image_shape, input_shape, input_range, color_coding,
preprocess_f, patterns.

Function parameters:

:param load_weights: Download or access cached weights.
:param load_patterns: Download or access cached patterns.
"""
# todo: rename in, sm_out, out to input_tensors, softmax_output_tensors,
# todo: output_tensors (respectively)
  14. # Get Python six functionality:
  15. from __future__ import\
  16. absolute_import, print_function, division, unicode_literals
  17. from builtins import range
  18. ###############################################################################
  19. ###############################################################################
  20. ###############################################################################
  21. import keras.backend as K
  22. import keras.applications.resnet50
  23. import keras.applications.vgg16
  24. import keras.applications.vgg19
  25. import keras.applications.inception_v3
  26. import keras.applications.inception_resnet_v2
  27. import keras.applications.densenet
  28. import keras.applications.nasnet
  29. import keras.utils.data_utils
  30. import numpy as np
  31. import warnings
  32. from ..utils.keras import graph as kgraph
  33. __all__ = [
  34. "vgg16",
  35. "vgg19",
  36. "resnet50",
  37. "inception_v3",
  38. "inception_resnet_v2",
  39. "densenet121",
  40. "densenet169",
  41. "densenet201",
  42. "nasnet_large",
  43. "nasnet_mobile",
  44. ]
  45. ###############################################################################
  46. ###############################################################################
  47. ###############################################################################
  48. PATTERNS = {
  49. "vgg16_pattern_type_relu_tf_dim_ordering_tf_kernels.npz": {
  50. "url": "https://www.dropbox.com/s/15lip81fzvbgkaa/vgg16_pattern_type_relu_tf_dim_ordering_tf_kernels.npz?dl=1",
  51. "hash": "8c2abe648e116a93fd5027fab49177b0",
  52. },
  53. "vgg19_pattern_type_relu_tf_dim_ordering_tf_kernels.npz": {
  54. "url": "https://www.dropbox.com/s/nc5empj78rfe9hm/vgg19_pattern_type_relu_tf_dim_ordering_tf_kernels.npz?dl=1",
  55. "hash": "3258b6c64537156afe75ca7b3be44742",
  56. },
  57. }
  58. def _get_patterns_info(netname, pattern_type):
  59. if pattern_type is True:
  60. pattern_type = "relu"
  61. file_name = ("%s_pattern_type_%s_tf_dim_ordering_tf_kernels.npz" %
  62. (netname, pattern_type))
  63. return {"file_name": file_name,
  64. "url": PATTERNS[file_name]["url"],
  65. "hash": PATTERNS[file_name]["hash"]}
  66. ###############################################################################
  67. ###############################################################################
  68. ###############################################################################
  69. def _prepare_keras_net(netname,
  70. clazz,
  71. image_shape,
  72. preprocess_f,
  73. preprocess_mode=None,
  74. color_coding="RGB",
  75. load_weights=False,
  76. load_patterns=False):
  77. net = {}
  78. net["name"] = netname
  79. net["image_shape"] = image_shape
  80. if K.image_data_format() == "channels_first":
  81. net["input_shape"] = [None, 3]+image_shape
  82. else:
  83. net["input_shape"] = [None]+image_shape+[3]
  84. weights = None
  85. if load_weights is True:
  86. weights = "imagenet"
  87. model = clazz(weights=weights,
  88. input_shape=tuple(net["input_shape"][1:]))
  89. net["model"] = model
  90. net["in"] = model.inputs
  91. net["sm_out"] = model.outputs
  92. net["out"] = kgraph.pre_softmax_tensors(model.outputs)
  93. net["color_coding"] = color_coding
  94. net["preprocess_f"] = preprocess_f
  95. net["input_range"] = {
  96. None: (-128, 128),
  97. "caffe": (-128, 128),
  98. "tf": (-1, 1),
  99. "torch": (-3, 3),
  100. }[preprocess_mode]
  101. net["patterns"] = None
  102. if load_patterns is not False:
  103. try:
  104. pattern_info = _get_patterns_info(netname, load_patterns)
  105. except KeyError:
  106. warnings.warn("There are no patterns for network '%s'." % netname)
  107. else:
  108. patterns_path = keras.utils.data_utils.get_file(
  109. pattern_info["file_name"],
  110. pattern_info["url"],
  111. cache_subdir="innvestigate_patterns",
  112. hash_algorithm="md5",
  113. file_hash=pattern_info["hash"])
  114. patterns_file = np.load(patterns_path)
  115. patterns = [patterns_file["arr_%i" % i]
  116. for i in range(len(patterns_file.keys()))]
  117. net["patterns"] = patterns
  118. return net
  119. ###############################################################################
  120. ###############################################################################
  121. ###############################################################################
  122. def vgg16(load_weights=False, load_patterns=False):
  123. return _prepare_keras_net(
  124. "vgg16",
  125. keras.applications.vgg16.VGG16,
  126. [224, 224],
  127. preprocess_f=keras.applications.vgg16.preprocess_input,
  128. preprocess_mode="caffe",
  129. color_coding="BGR",
  130. load_weights=load_weights,
  131. load_patterns=load_patterns)
  132. def vgg19(load_weights=False, load_patterns=False):
  133. return _prepare_keras_net(
  134. "vgg19",
  135. keras.applications.vgg19.VGG19,
  136. [224, 224],
  137. preprocess_f=keras.applications.vgg19.preprocess_input,
  138. preprocess_mode="caffe",
  139. color_coding="BGR",
  140. load_weights=load_weights,
  141. load_patterns=load_patterns)
  142. ###############################################################################
  143. ###############################################################################
  144. ###############################################################################
  145. def resnet50(load_weights=False, load_patterns=False):
  146. return _prepare_keras_net(
  147. "resnet50",
  148. keras.applications.resnet50.ResNet50,
  149. [224, 224],
  150. preprocess_f=keras.applications.resnet50.preprocess_input,
  151. preprocess_mode="caffe",
  152. color_coding="BGR",
  153. load_weights=load_weights,
  154. load_patterns=load_patterns)
  155. ###############################################################################
  156. ###############################################################################
  157. ###############################################################################
  158. def inception_v3(load_weights=False, load_patterns=False):
  159. return _prepare_keras_net(
  160. "inception_v3",
  161. keras.applications.inception_v3.InceptionV3,
  162. [299, 299],
  163. preprocess_f=keras.applications.inception_v3.preprocess_input,
  164. preprocess_mode="tf",
  165. load_weights=load_weights,
  166. load_patterns=load_patterns)
  167. ###############################################################################
  168. ###############################################################################
  169. ###############################################################################
  170. def inception_resnet_v2(load_weights=False, load_patterns=False):
  171. return _prepare_keras_net(
  172. "inception_resnet_v2",
  173. keras.applications.inception_resnet_v2.InceptionResNetV2,
  174. [299, 299],
  175. preprocess_f=keras.applications.inception_resnet_v2.preprocess_input,
  176. preprocess_mode="tf",
  177. load_weights=load_weights,
  178. load_patterns=load_patterns)
  179. ###############################################################################
  180. ###############################################################################
  181. ###############################################################################
  182. def densenet121(load_weights=False, load_patterns=False):
  183. return _prepare_keras_net(
  184. "densenet121",
  185. keras.applications.densenet.DenseNet121,
  186. [224, 224],
  187. preprocess_f=keras.applications.densenet.preprocess_input,
  188. preprocess_mode="torch",
  189. load_weights=load_weights,
  190. load_patterns=load_patterns)
  191. def densenet169(load_weights=False, load_patterns=False):
  192. return _prepare_keras_net(
  193. "densenet169",
  194. keras.applications.densenet.DenseNet169,
  195. [224, 224],
  196. preprocess_f=keras.applications.densenet.preprocess_input,
  197. preprocess_mode="torch",
  198. load_weights=load_weights,
  199. load_patterns=load_patterns)
  200. def densenet201(load_weights=False, load_patterns=False):
  201. return _prepare_keras_net(
  202. "densenet201",
  203. keras.applications.densenet.DenseNet201,
  204. [224, 224],
  205. preprocess_f=keras.applications.densenet.preprocess_input,
  206. preprocess_mode="torch",
  207. load_weights=load_weights,
  208. load_patterns=load_patterns)
  209. ###############################################################################
  210. ###############################################################################
  211. ###############################################################################
  212. def nasnet_large(load_weights=False, load_patterns=False):
  213. if K.image_data_format() == "channels_first":
  214. raise Exception("NASNet is not available for channels first.")
  215. return _prepare_keras_net(
  216. "nasnet_large",
  217. keras.applications.nasnet.NASNetLarge,
  218. [331, 331],
  219. color_coding="BGR",
  220. preprocess_f=keras.applications.nasnet.preprocess_input,
  221. preprocess_mode="tf",
  222. load_weights=load_weights,
  223. load_patterns=load_patterns)
  224. def nasnet_mobile(load_weights=False, load_patterns=False):
  225. if K.image_data_format() == "channels_first":
  226. raise Exception("NASNet is not available for channels first.")
  227. return _prepare_keras_net(
  228. "nasnet_mobile",
  229. keras.applications.nasnet.NASNetMobile,
  230. [224, 224],
  231. color_coding="BGR",
  232. preprocess_f=keras.applications.nasnet.preprocess_input,
  233. preprocess_mode="tf",
  234. load_weights=load_weights,
  235. load_patterns=load_patterns)