checks.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435
  1. # Get Python six functionality:
  2. from __future__ import\
  3. absolute_import, print_function, division, unicode_literals
  4. ###############################################################################
  5. ###############################################################################
  6. ###############################################################################
  7. import inspect
  8. import keras.engine.topology
  9. import keras.layers
  10. import keras.layers.advanced_activations
  11. import keras.layers.convolutional
  12. import keras.layers.convolutional_recurrent
  13. import keras.layers.core
  14. import keras.layers.cudnn_recurrent
  15. import keras.layers.embeddings
  16. import keras.layers.local
  17. import keras.layers.noise
  18. import keras.layers.normalization
  19. import keras.layers.pooling
  20. import keras.layers.recurrent
  21. import keras.layers.wrappers
  22. import keras.legacy.layers
  23. # Prevents circular imports.
  24. def get_kgraph():
  25. from . import graph as kgraph
  26. return kgraph
  27. __all__ = [
  28. "get_current_layers",
  29. "get_known_layers",
  30. "get_activation_search_safe_layers",
  31. "contains_activation",
  32. "contains_kernel",
  33. "only_relu_activation",
  34. "is_network",
  35. "is_convnet_layer",
  36. "is_relu_convnet_layer",
  37. "is_average_pooling",
  38. "is_max_pooling",
  39. "is_input_layer",
  40. "is_batch_normalization_layer",
  41. "is_embedding_layer"
  42. ]
  43. ###############################################################################
  44. ###############################################################################
  45. ###############################################################################
  46. def get_current_layers():
  47. """
  48. Returns a list of currently available layers in Keras.
  49. """
  50. class_set = set([(getattr(keras.layers, name), name)
  51. for name in dir(keras.layers)
  52. if (inspect.isclass(getattr(keras.layers, name)) and
  53. issubclass(getattr(keras.layers, name),
  54. keras.engine.topology.Layer))])
  55. return [x[1] for x in sorted((str(x[0]), x[1]) for x in class_set)]
  56. def get_known_layers():
  57. """
  58. Returns a list of keras layer we are aware of.
  59. """
  60. # Inside function to not break import if Keras changes.
  61. KNOWN_LAYERS = (
  62. keras.engine.topology.InputLayer,
  63. keras.layers.advanced_activations.ELU,
  64. keras.layers.advanced_activations.LeakyReLU,
  65. keras.layers.advanced_activations.PReLU,
  66. keras.layers.advanced_activations.Softmax,
  67. keras.layers.advanced_activations.ThresholdedReLU,
  68. keras.layers.convolutional.Conv1D,
  69. keras.layers.convolutional.Conv2D,
  70. keras.layers.convolutional.Conv2DTranspose,
  71. keras.layers.convolutional.Conv3D,
  72. keras.layers.convolutional.Conv3DTranspose,
  73. keras.layers.convolutional.Cropping1D,
  74. keras.layers.convolutional.Cropping2D,
  75. keras.layers.convolutional.Cropping3D,
  76. keras.layers.convolutional.SeparableConv1D,
  77. keras.layers.convolutional.SeparableConv2D,
  78. keras.layers.convolutional.UpSampling1D,
  79. keras.layers.convolutional.UpSampling2D,
  80. keras.layers.convolutional.UpSampling3D,
  81. keras.layers.convolutional.ZeroPadding1D,
  82. keras.layers.convolutional.ZeroPadding2D,
  83. keras.layers.convolutional.ZeroPadding3D,
  84. keras.layers.convolutional_recurrent.ConvLSTM2D,
  85. keras.layers.convolutional_recurrent.ConvRecurrent2D,
  86. keras.layers.core.Activation,
  87. keras.layers.core.ActivityRegularization,
  88. keras.layers.core.Dense,
  89. keras.layers.core.Dropout,
  90. keras.layers.core.Flatten,
  91. keras.layers.core.Lambda,
  92. keras.layers.core.Masking,
  93. keras.layers.core.Permute,
  94. keras.layers.core.RepeatVector,
  95. keras.layers.core.Reshape,
  96. keras.layers.core.SpatialDropout1D,
  97. keras.layers.core.SpatialDropout2D,
  98. keras.layers.core.SpatialDropout3D,
  99. keras.layers.cudnn_recurrent.CuDNNGRU,
  100. keras.layers.cudnn_recurrent.CuDNNLSTM,
  101. keras.layers.embeddings.Embedding,
  102. keras.layers.local.LocallyConnected1D,
  103. keras.layers.local.LocallyConnected2D,
  104. keras.layers.Add,
  105. keras.layers.Average,
  106. keras.layers.Concatenate,
  107. keras.layers.Dot,
  108. keras.layers.Maximum,
  109. keras.layers.Minimum,
  110. keras.layers.Multiply,
  111. keras.layers.Subtract,
  112. keras.layers.noise.AlphaDropout,
  113. keras.layers.noise.GaussianDropout,
  114. keras.layers.noise.GaussianNoise,
  115. keras.layers.normalization.BatchNormalization,
  116. keras.layers.pooling.AveragePooling1D,
  117. keras.layers.pooling.AveragePooling2D,
  118. keras.layers.pooling.AveragePooling3D,
  119. keras.layers.pooling.GlobalAveragePooling1D,
  120. keras.layers.pooling.GlobalAveragePooling2D,
  121. keras.layers.pooling.GlobalAveragePooling3D,
  122. keras.layers.pooling.GlobalMaxPooling1D,
  123. keras.layers.pooling.GlobalMaxPooling2D,
  124. keras.layers.pooling.GlobalMaxPooling3D,
  125. keras.layers.pooling.MaxPooling1D,
  126. keras.layers.pooling.MaxPooling2D,
  127. keras.layers.pooling.MaxPooling3D,
  128. keras.layers.recurrent.GRU,
  129. keras.layers.recurrent.GRUCell,
  130. keras.layers.recurrent.LSTM,
  131. keras.layers.recurrent.LSTMCell,
  132. keras.layers.recurrent.RNN,
  133. keras.layers.recurrent.SimpleRNN,
  134. keras.layers.recurrent.SimpleRNNCell,
  135. keras.layers.recurrent.StackedRNNCells,
  136. keras.layers.wrappers.Bidirectional,
  137. keras.layers.wrappers.TimeDistributed,
  138. keras.layers.wrappers.Wrapper,
  139. keras.legacy.layers.Highway,
  140. keras.legacy.layers.MaxoutDense,
  141. keras.legacy.layers.Merge,
  142. keras.legacy.layers.Recurrent,
  143. )
  144. return KNOWN_LAYERS
  145. def get_activation_search_safe_layers():
  146. """
  147. Returns a list of keras layer that we can walk along
  148. in an activation search.
  149. """
  150. # Inside function to not break import if Keras changes.
  151. ACTIVATION_SEARCH_SAFE_LAYERS = (
  152. keras.layers.advanced_activations.ELU,
  153. keras.layers.advanced_activations.LeakyReLU,
  154. keras.layers.advanced_activations.PReLU,
  155. keras.layers.advanced_activations.Softmax,
  156. keras.layers.advanced_activations.ThresholdedReLU,
  157. keras.layers.core.Activation,
  158. keras.layers.core.ActivityRegularization,
  159. keras.layers.core.Dropout,
  160. keras.layers.core.Flatten,
  161. keras.layers.core.Reshape,
  162. keras.layers.Add,
  163. keras.layers.noise.GaussianNoise,
  164. keras.layers.normalization.BatchNormalization,
  165. )
  166. return ACTIVATION_SEARCH_SAFE_LAYERS
  167. ###############################################################################
  168. ###############################################################################
  169. ###############################################################################
  170. def contains_activation(layer, activation=None):
  171. """
  172. Check whether the layer contains an activation function.
  173. activation is None then we only check if layer can contain an activation.
  174. """
  175. # todo: add test and check this more throughroughly.
  176. # rely on Keras convention.
  177. if hasattr(layer, "activation"):
  178. if activation is not None:
  179. return layer.activation == keras.activations.get(activation)
  180. else:
  181. return True
  182. elif isinstance(layer, keras.layers.ReLU):
  183. if activation is not None:
  184. return (keras.activations.get("relu") ==
  185. keras.activations.get(activation))
  186. else:
  187. return True
  188. elif isinstance(layer, (
  189. keras.layers.advanced_activations.ELU,
  190. keras.layers.advanced_activations.LeakyReLU,
  191. keras.layers.advanced_activations.PReLU,
  192. keras.layers.advanced_activations.Softmax,
  193. keras.layers.advanced_activations.ThresholdedReLU)):
  194. if activation is not None:
  195. raise Exception("Cannot detect activation type.")
  196. else:
  197. return True
  198. else:
  199. return False
  200. def contains_kernel(layer):
  201. """
  202. Check whether the layer contains a kernel.
  203. """
  204. # TODO: add test and check this more throughroughly.
  205. # rely on Keras convention.
  206. if hasattr(layer, "kernel") or hasattr(layer, "depthwise_kernel") or hasattr(layer, "pointwise_kernel"):
  207. return True
  208. else:
  209. return False
  210. def contains_bias(layer):
  211. """
  212. Check whether the layer contains a bias.
  213. """
  214. # todo: add test and check this more throughroughly.
  215. # rely on Keras convention.
  216. if hasattr(layer, "bias"):
  217. return True
  218. else:
  219. return False
  220. def only_relu_activation(layer):
  221. """Checks if layer contains no or only a ReLU activation."""
  222. return (not contains_activation(layer) or
  223. contains_activation(layer, None) or
  224. contains_activation(layer, "linear") or
  225. contains_activation(layer, "relu"))
  226. def is_network(layer):
  227. """
  228. Is network in network?
  229. """
  230. return isinstance(layer, keras.engine.topology.Network)
  231. def is_conv_layer(layer, *args, **kwargs):
  232. """Checks if layer is a convolutional layer."""
  233. CONV_LAYERS = (
  234. keras.layers.convolutional.Conv1D,
  235. keras.layers.convolutional.Conv2D,
  236. keras.layers.convolutional.Conv2DTranspose,
  237. keras.layers.convolutional.Conv3D,
  238. keras.layers.convolutional.Conv3DTranspose,
  239. keras.layers.convolutional.SeparableConv1D,
  240. keras.layers.convolutional.SeparableConv2D,
  241. keras.layers.convolutional.DepthwiseConv2D
  242. )
  243. return isinstance(layer, CONV_LAYERS)
  244. def is_embedding_layer(layer, *args, **kwargs):
  245. return isinstance(layer, keras.layers.Embedding)
  246. def is_batch_normalization_layer(layer, *args, **kwargs):
  247. """Checks if layer is a batchnorm layer."""
  248. return isinstance(layer, keras.layers.normalization.BatchNormalization)
  249. def is_add_layer(layer, *args, **kwargs):
  250. """Checks if layer is an addition-merge layer."""
  251. return isinstance(layer, keras.layers.Add)
  252. def is_dense_layer(layer, *args, **kwargs):
  253. """Checks if layer is a dense layer."""
  254. return isinstance(layer, keras.layers.core.Dense)
  255. def is_convnet_layer(layer):
  256. """Checks if layer is from a convolutional network."""
  257. # Inside function to not break import if Keras changes.
  258. CONVNET_LAYERS = (
  259. keras.engine.topology.InputLayer,
  260. keras.layers.advanced_activations.ELU,
  261. keras.layers.advanced_activations.LeakyReLU,
  262. keras.layers.advanced_activations.PReLU,
  263. keras.layers.advanced_activations.Softmax,
  264. keras.layers.advanced_activations.ThresholdedReLU,
  265. keras.layers.convolutional.Conv1D,
  266. keras.layers.convolutional.Conv2D,
  267. keras.layers.convolutional.Conv2DTranspose,
  268. keras.layers.convolutional.Conv3D,
  269. keras.layers.convolutional.Conv3DTranspose,
  270. keras.layers.convolutional.Cropping1D,
  271. keras.layers.convolutional.Cropping2D,
  272. keras.layers.convolutional.Cropping3D,
  273. keras.layers.convolutional.SeparableConv1D,
  274. keras.layers.convolutional.SeparableConv2D,
  275. keras.layers.convolutional.UpSampling1D,
  276. keras.layers.convolutional.UpSampling2D,
  277. keras.layers.convolutional.UpSampling3D,
  278. keras.layers.convolutional.ZeroPadding1D,
  279. keras.layers.convolutional.ZeroPadding2D,
  280. keras.layers.convolutional.ZeroPadding3D,
  281. keras.layers.core.Activation,
  282. keras.layers.core.ActivityRegularization,
  283. keras.layers.core.Dense,
  284. keras.layers.core.Dropout,
  285. keras.layers.core.Flatten,
  286. keras.layers.core.Lambda,
  287. keras.layers.core.Masking,
  288. keras.layers.core.Permute,
  289. keras.layers.core.RepeatVector,
  290. keras.layers.core.Reshape,
  291. keras.layers.core.SpatialDropout1D,
  292. keras.layers.core.SpatialDropout2D,
  293. keras.layers.core.SpatialDropout3D,
  294. keras.layers.embeddings.Embedding,
  295. keras.layers.local.LocallyConnected1D,
  296. keras.layers.local.LocallyConnected2D,
  297. keras.layers.Add,
  298. keras.layers.Average,
  299. keras.layers.Concatenate,
  300. keras.layers.Dot,
  301. keras.layers.Maximum,
  302. keras.layers.Minimum,
  303. keras.layers.Multiply,
  304. keras.layers.Subtract,
  305. keras.layers.noise.AlphaDropout,
  306. keras.layers.noise.GaussianDropout,
  307. keras.layers.noise.GaussianNoise,
  308. keras.layers.normalization.BatchNormalization,
  309. keras.layers.pooling.AveragePooling1D,
  310. keras.layers.pooling.AveragePooling2D,
  311. keras.layers.pooling.AveragePooling3D,
  312. keras.layers.pooling.GlobalAveragePooling1D,
  313. keras.layers.pooling.GlobalAveragePooling2D,
  314. keras.layers.pooling.GlobalAveragePooling3D,
  315. keras.layers.pooling.GlobalMaxPooling1D,
  316. keras.layers.pooling.GlobalMaxPooling2D,
  317. keras.layers.pooling.GlobalMaxPooling3D,
  318. keras.layers.pooling.MaxPooling1D,
  319. keras.layers.pooling.MaxPooling2D,
  320. keras.layers.pooling.MaxPooling3D,
  321. )
  322. return isinstance(layer, CONVNET_LAYERS)
  323. def is_relu_convnet_layer(layer):
  324. """Checks if layer is from a convolutional network with ReLUs."""
  325. return (is_convnet_layer(layer) and only_relu_activation(layer))
  326. def is_average_pooling(layer):
  327. """Checks if layer is an average-pooling layer."""
  328. AVERAGEPOOLING_LAYERS = (
  329. keras.layers.pooling.AveragePooling1D,
  330. keras.layers.pooling.AveragePooling2D,
  331. keras.layers.pooling.AveragePooling3D,
  332. keras.layers.pooling.GlobalAveragePooling1D,
  333. keras.layers.pooling.GlobalAveragePooling2D,
  334. keras.layers.pooling.GlobalAveragePooling3D,
  335. )
  336. return isinstance(layer, AVERAGEPOOLING_LAYERS)
  337. def is_max_pooling(layer):
  338. """Checks if layer is a max-pooling layer."""
  339. MAXPOOLING_LAYERS = (
  340. keras.layers.pooling.MaxPooling1D,
  341. keras.layers.pooling.MaxPooling2D,
  342. keras.layers.pooling.MaxPooling3D,
  343. keras.layers.pooling.GlobalMaxPooling1D,
  344. keras.layers.pooling.GlobalMaxPooling2D,
  345. keras.layers.pooling.GlobalMaxPooling3D,
  346. )
  347. return isinstance(layer, MAXPOOLING_LAYERS)
  348. def is_input_layer(layer, ignore_reshape_layers=True):
  349. """Checks if layer is an input layer."""
  350. # Triggers if ALL inputs of layer are connected
  351. # to a Keras input layer object.
  352. # Note: In the sequential api the Sequential object
  353. # adds the Input layer if the user does not.
  354. kgraph = get_kgraph()
  355. layer_inputs = kgraph.get_input_layers(layer)
  356. # We ignore certain layers, that do not modify
  357. # the data content.
  358. # todo: update this list!
  359. IGNORED_LAYERS = (
  360. keras.layers.Flatten,
  361. keras.layers.Permute,
  362. keras.layers.Reshape,
  363. )
  364. while any([isinstance(x, IGNORED_LAYERS) for x in layer_inputs]):
  365. tmp = set()
  366. for l in layer_inputs:
  367. if(ignore_reshape_layers and
  368. isinstance(l, IGNORED_LAYERS)):
  369. tmp.update(kgraph.get_input_layers(l))
  370. else:
  371. tmp.add(l)
  372. layer_inputs = tmp
  373. if all([isinstance(x, keras.layers.InputLayer)
  374. for x in layer_inputs]):
  375. return True
  376. else:
  377. return False
  378. def is_layer_at_idx(layer, index, ignore_reshape_layers=True):
  379. """Checks if layer is a layer at index index, by repeatedly applying is_input_layer()."""
  380. kgraph = get_kgraph()