test_pattern.py 12 KB


  1. # Get Python six functionality:
  2. from __future__ import\
  3. absolute_import, print_function, division, unicode_literals
  4. ###############################################################################
  5. ###############################################################################
  6. ###############################################################################
  7. import pytest
  8. from keras.datasets import mnist
  9. import keras.layers
  10. import keras.models
  11. from keras.models import Model
  12. import keras.optimizers
  13. import numpy as np
  14. import unittest
  15. from innvestigate.utils.tests import dryrun
  16. import innvestigate
  17. from innvestigate.tools import PatternComputer
  18. ###############################################################################
  19. ###############################################################################
  20. ###############################################################################
  @pytest.mark.fast
  @pytest.mark.precommit
  def test_fast__PatternComputer_dummy_parallel():
      """Dry-run a "dummy" PatternComputer that computes layers in parallel.

      Uses dryrun.test_pattern_computer with the "mnist.log_reg" network filter.
      """
      def make_pattern_computer(model):
          return PatternComputer(
              model,
              pattern_type="dummy",
              compute_layers_in_parallel=True,
          )

      dryrun.test_pattern_computer(make_pattern_computer, "mnist.log_reg")
  @pytest.mark.skip(reason="Feature not supported.")
  @pytest.mark.fast
  @pytest.mark.precommit
  def test_fast__PatternComputer_dummy_sequential():
      """Sequential counterpart of the "dummy" PatternComputer dry run.

      Builds the computer with compute_layers_in_parallel=False. The test is
      unconditionally skipped; the skip reason says the feature is not
      supported.
      """
      def make_pattern_computer(model):
          return PatternComputer(
              model,
              pattern_type="dummy",
              compute_layers_in_parallel=False,
          )

      dryrun.test_pattern_computer(make_pattern_computer, "mnist.log_reg")
  36. ###############################################################################
  37. ###############################################################################
  38. ###############################################################################
  @pytest.mark.fast
  @pytest.mark.precommit
  def test_fast__PatternComputer_linear():
      """Dry-run a "linear" PatternComputer for the "mnist.log_reg" filter."""
      def make_pattern_computer(model):
          return PatternComputer(model, pattern_type="linear")

      dryrun.test_pattern_computer(make_pattern_computer, "mnist.log_reg")
  @pytest.mark.precommit
  def test_precommit__PatternComputer_linear():
      """Dry-run a "linear" PatternComputer for the "mnist.*" network filter."""
      def make_pattern_computer(model):
          return PatternComputer(model, pattern_type="linear")

      dryrun.test_pattern_computer(make_pattern_computer, "mnist.*")
  @pytest.mark.fast
  @pytest.mark.precommit
  def test_fast__PatternComputer_relupositive():
      """Dry-run a "relu.positive" PatternComputer for "mnist.log_reg"."""
      def make_pattern_computer(model):
          return PatternComputer(model, pattern_type="relu.positive")

      dryrun.test_pattern_computer(make_pattern_computer, "mnist.log_reg")
  @pytest.mark.precommit
  def test_precommit__PatternComputer_relupositive():
      """Dry-run a "relu.positive" PatternComputer for the "mnist.*" filter."""
      def make_pattern_computer(model):
          return PatternComputer(model, pattern_type="relu.positive")

      dryrun.test_pattern_computer(make_pattern_computer, "mnist.*")
  @pytest.mark.fast
  @pytest.mark.precommit
  def test_fast__PatternComputer_relunegative():
      """Dry-run a "relu.negative" PatternComputer for "mnist.log_reg"."""
      def make_pattern_computer(model):
          return PatternComputer(model, pattern_type="relu.negative")

      dryrun.test_pattern_computer(make_pattern_computer, "mnist.log_reg")
  @pytest.mark.precommit
  def test_precommit__PatternComputer_relunegative():
      """Dry-run a "relu.negative" PatternComputer for the "mnist.*" filter."""
      def make_pattern_computer(model):
          return PatternComputer(model, pattern_type="relu.negative")

      dryrun.test_pattern_computer(make_pattern_computer, "mnist.*")
  72. ###############################################################################
  73. ###############################################################################
  74. ###############################################################################
  @pytest.mark.fast
  @pytest.mark.precommit
  class HaufePatternExample(unittest.TestCase):
      """Linear toy problem with a known signal pattern.

      Each sample is x = y * a_s + eps * a_d, with signal direction
      a_s = [1, 0] and distractor direction a_d = [1, 1], i.e.
      x = [y + eps, eps]. A linear model that regresses y from x has to cancel
      the distractor, so its ideal weight vector is [1, -1] (orthogonal to
      a_d). The "linear" pattern computed for that model should recover a_s.
      """

      def test(self):
          np.random.seed(234354346)
          # Many samples are needed to get close to the optimum and to keep
          # the resulting numbers stable.
          n = 1000
          a_s = np.asarray([1, 0]).reshape((1, 2))  # signal direction
          a_d = np.asarray([1, 1]).reshape((1, 2))  # distractor direction
          y = np.random.uniform(size=(n, 1))
          eps = np.random.rand(n, 1)
          X = y * a_s + eps * a_d

          model = keras.models.Sequential(
              [keras.layers.Dense(1, input_shape=(2,), use_bias=True)]
          )
          model.compile(optimizer=keras.optimizers.Adam(lr=1), loss="mse")
          model.fit(X, y, epochs=20, verbose=0)
          # The regression itself must have converged before the pattern
          # checks below are meaningful.
          self.assertLess(model.evaluate(X, y, verbose=0), 0.05)

          pc = PatternComputer(model, pattern_type="linear")
          A = pc.compute(X)[0]
          W = model.get_weights()[0]

          def allclose(a, b):
              return np.allclose(a, b, rtol=0.05, atol=0.05)

          # The ideal filter is [1, -1], which is orthogonal to a_d. Because
          # a_d = [1, 1], the filter's magnitudes are compared elementwise
          # with a_d. The sign is implied by the low MSE asserted above:
          # [-1, 1] or [1, 1] could not fit y.
          self.assertTrue(allclose(a_d.ravel(), abs(W.ravel())))
          # The estimated pattern should be close to the true signal pattern.
          self.assertTrue(allclose(a_s.ravel(), A.ravel()))
  104. ###############################################################################
  105. ###############################################################################
  106. ###############################################################################
  def fetch_data(n_train=100, n_test=10):
      """Load a small, normalized MNIST subset in (n, 1, 28, 28) layout.

      Args:
        n_train: number of leading training samples to keep (default 100).
        n_test: number of leading test samples to keep (default 10).

      Returns:
        (x_train, y_train, x_test, y_test). Images are float32 with shape
        (n, 1, 28, 28) and pixel values mapped from [0, 255] to [-1, 1].
        Labels are returned unchanged from mnist.load_data().
      """
      # MNIST comes pre-split into train and test sets.
      (x_train, y_train), (x_test, y_test) = mnist.load_data()

      def _prepare(images, count):
          # Keep only the first `count` images *before* converting, so only
          # the retained images are normalized and cast rather than the
          # full dataset. Using -1 avoids hard-coding the dataset size.
          subset = images[:count].reshape((-1, 1, 28, 28))
          return ((subset - 127.5) / 127.5).astype('float32')

      return (_prepare(x_train, n_train), y_train[:n_train],
              _prepare(x_test, n_test), y_test[:n_test])
  def create_model(clazz):
      """Instantiate a test network and wrap two of its outputs in Keras models.

      `clazz` is called with input shape (None, 1, 28, 28), 10 classes,
      dense_units=1024 and dropout_rate=0.25. It must return a mapping that
      provides the tensors "in", "out" and "sm_out".

      Returns:
        (model_wo_sm, model_w_sm): a model ending at network["out"] and a
        model ending at network["sm_out"]. Per the names, presumably
        without/with a final softmax.
      """
      net = clazz((None, 1, 28, 28), 10, dense_units=1024, dropout_rate=0.25)
      return (
          Model(inputs=net["in"], outputs=net["out"]),
          Model(inputs=net["in"], outputs=net["sm_out"]),
      )
  def train_model(model, data, epochs=20, batch_size=128, num_classes=10):
      """Compile, fit and evaluate `model` on `data`.

      Args:
        model: Keras model. It is (re)compiled here with categorical
          cross-entropy, RMSprop and an accuracy metric.
        data: tuple (x_train, y_train, x_test, y_test) with class-index
          labels, e.g. as returned by fetch_data().
        epochs: number of training epochs.
        batch_size: batch size used for both fitting and evaluation.
        num_classes: number of classes used for one-hot encoding the labels.

      Returns:
        The result of model.evaluate on the test split. It was previously
        computed and thrown away; callers that ignore it are unaffected.
      """
      x_train, y_train, x_test, y_test = data
      # Convert class vectors to binary (one-hot) class matrices.
      y_train = keras.utils.to_categorical(y_train, num_classes)
      y_test = keras.utils.to_categorical(y_test, num_classes)

      model.compile(loss='categorical_crossentropy',
                    optimizer=keras.optimizers.RMSprop(),
                    metrics=['accuracy'])
      model.fit(x_train, y_train,
                batch_size=batch_size,
                epochs=epochs,
                verbose=0)
      return model.evaluate(x_test, y_test, batch_size=batch_size, verbose=0)
  @pytest.mark.fast
  @pytest.mark.precommit
  class MnistPatternExample_dense_linear(unittest.TestCase):
      """Compare the analyzer's first "linear" pattern with a NumPy reference.

      The reference is A = cov(x, y) / diag(W^T cov(x, y)) with y = x W, where
      W is the model's first weight array flattened to (inputs, units). The
      bias is not included in y.
      """

      def test(self):
          np.random.seed(234354346)
          model_class = innvestigate.utils.tests.networks.base.mlp_2dense

          data = fetch_data()
          model, modelp = create_model(model_class)
          train_model(modelp, data, epochs=10)
          model.set_weights(modelp.get_weights())

          analyzer = innvestigate.create_analyzer("pattern.net", model,
                                                  pattern_type="linear")
          analyzer.fit(data[0], batch_size=256, verbose=0)
          patterns = analyzer._patterns

          # Assumed to be the kernel of the first Dense layer -- TODO confirm
          # against mlp_2dense.
          W = model.get_weights()[0]
          W2D = W.reshape((-1, W.shape[-1]))
          X = data[0].reshape((data[0].shape[0], -1))
          Y = np.dot(X, W2D)

          def safe_divide(a, b):
              # Zero denominators are replaced by 1 to avoid division by zero.
              return a / (b + (b == 0))

          mean_x = X.mean(axis=0)
          mean_y = Y.mean(axis=0)
          mean_xy = np.dot(X.T, Y) / Y.shape[0]
          cov_xy = mean_xy - mean_x[:, None] * mean_y[None, :]
          # Equivalent to np.diag(np.dot(W2D.T, cov_xy)) without building the
          # full (units x units) product: entry j is
          # sum_i W2D[i, j] * cov_xy[i, j].
          w_cov_xy = (W2D * cov_xy).sum(axis=0)
          A = safe_divide(cov_xy, w_cov_xy[None, :])

          self.assertTrue(np.allclose(A.ravel(), patterns[0].ravel(),
                                      rtol=0.05, atol=0.05))
  @pytest.mark.fast
  @pytest.mark.precommit
  class MnistPatternExample_dense_relu(unittest.TestCase):
      """Compare the analyzer's first "relu" pattern with a NumPy reference.

      All statistics (E+[x], E+[y], E+[xy]) are restricted, per unit, to the
      samples whose pre-activation x W + b is positive. The reference pattern
      is A = cov+(x, y) / diag(W^T cov+(x, y)), where y = x W excludes the bias.
      """

      def test(self):
          np.random.seed(234354346)
          model_class = innvestigate.utils.tests.networks.base.mlp_2dense

          data = fetch_data()
          model, modelp = create_model(model_class)
          train_model(modelp, data, epochs=10)
          model.set_weights(modelp.get_weights())

          analyzer = innvestigate.create_analyzer("pattern.net", model,
                                                  pattern_type="relu")
          analyzer.fit(data[0], batch_size=256, verbose=0)
          patterns = analyzer._patterns

          # Assumed to be kernel and bias of the first Dense layer -- TODO
          # confirm against mlp_2dense.
          W, b = model.get_weights()[:2]
          W2D = W.reshape((-1, W.shape[-1]))
          X = data[0].reshape((data[0].shape[0], -1))
          Y = np.dot(X, W2D)

          # Per-unit mask of samples where the unit is active.
          mask = Y + b > 0
          count = mask.sum(axis=0)

          def safe_divide(a, b):
              # Zero denominators are replaced by 1 to avoid division by zero
              # (e.g. for units that are never active).
              return a / (b + (b == 0))

          Y_active = Y * mask
          mean_x = safe_divide(np.dot(X.T, mask), count)
          # E+[y] must use the same active-sample restriction as E+[x] and
          # E+[xy]. The previous Y.mean(axis=0) averaged over all samples,
          # mixing two different conditionings in the covariance.
          mean_y = safe_divide(Y_active.sum(axis=0), count)
          mean_xy = safe_divide(np.dot(X.T, Y_active), count)
          cov_xy = mean_xy - mean_x * mean_y
          # Equivalent to np.diag(np.dot(W2D.T, cov_xy)) without building the
          # full (units x units) product.
          w_cov_xy = (W2D * cov_xy).sum(axis=0)
          A = safe_divide(cov_xy, w_cov_xy[None, :])

          self.assertTrue(np.allclose(A.ravel(), patterns[0].ravel(),
                                      rtol=0.05, atol=0.05))
  204. # def extract_2d_patches(X, conv_layer):
  205. # X_in = X
  206. # kernel_shape = conv_layer.kernel_size
  207. # strides = conv_layer.strides
  208. # rates = conv_layer.dilation_rate
  209. # padding = conv_layer.padding
  210. # assert all([x == 1 for x in rates])
  211. # assert all([x == 3 for x in kernel_shape])
  212. # assert all([x == 1 for x in strides])
  213. # if padding.lower() == "same":
  214. # tmp = np.ones(list(X.shape[:2])+[x+3 for x in X.shape[2:]],
  215. # dtype=X.dtype)
  216. # tmp[:, :, 1:-2, 1:-2] = X
  217. # X = tmp
  218. # out_shape = [int(np.ceil((x-k)/s))
  219. # for x, k, s in zip(X.shape[2:], kernel_shape, strides)]
  220. # n_patches = np.prod(list(X.shape[:2])+out_shape)
  221. # dimensions = X.shape[1]*kernel_shape[0]*kernel_shape[1]
  222. # ret = np.empty((n_patches, dimensions), dtype=X.dtype)
  223. # i_ret = 0
  224. # for j in range(X.shape[2]-kernel_shape[0]):
  225. # for k in range(X.shape[3]-kernel_shape[1]):
  226. # patches = X[:, :, j:j+kernel_shape[0], k:k+kernel_shape[1]]
  227. # patches = patches.reshape((-1, dimensions))
  228. # ret[i_ret:i_ret+X.shape[0]] = patches
  229. # i_ret += X.shape[0]
  230. # if True:
  231. # import tensorflow as tf
  232. # with tf.Session():
  233. # tf_ret = tf.extract_image_patches(
  234. # images=X_in.transpose((0, 2, 3, 1)),
  235. # ksizes=[1, kernel_shape[0], kernel_shape[1], 1],
  236. # strides=[1, strides[0], strides[1], 1],
  237. # rates=[1, rates[0], rates[1], 1],
  238. # padding=padding.upper()).eval()
  239. # tf_ret = tf_ret.reshape((-1, tf_ret.shape[-1]))
  240. # #print(tf_ret.shape, ret.shape)
  241. # assert tf_ret.shape == ret.shape
  242. # #print(tf_ret.mean(), ret.mean())
  243. # assert tf_ret.mean() == ret.mean()
  244. # assert i_ret == n_patches
  245. # return ret
  246. # class __disabled__MnistPatternExample_conv_linear(unittest.TestCase):
  247. # def test(self):
  248. # np.random.seed(234354346)
  249. # K.set_image_data_format("channels_first")
  250. # model_class = innvestigate.utils.tests.networks.base.cnn_2convb_2dense
  251. # data = fetch_data()
  252. # model, modelp = create_model(model_class)
  253. # train_model(modelp, data, epochs=1)
  254. # model.set_weights(modelp.get_weights())
  255. # analyzer = innvestigate.create_analyzer("pattern.net", model)
  256. # analyzer.fit(data[0], pattern_type="linear",
  257. # batch_size=256, verbose=0)
  258. # patterns = analyzer._patterns
  259. # W = model.get_weights()[0]
  260. # W2D = W.reshape((-1, W.shape[-1]))
  261. # X = extract_2d_patches(data[0], model.layers[1])
  262. # Y = np.dot(X, W2D)
  263. # def safe_divide(a, b):
  264. # return a / (b + (b == 0))
  265. # mean_x = X.mean(axis=0)
  266. # mean_y = Y.mean(axis=0)
  267. # mean_xy = np.dot(X.T, Y) / Y.shape[0]
  268. # ExEy = mean_x[:, None] * mean_y[None, :]
  269. # cov_xy = mean_xy - ExEy
  270. # w_cov_xy = np.diag(np.dot(W2D.T, cov_xy))
  271. # A = safe_divide(cov_xy, w_cov_xy[None, :])
  272. # def allclose(a, b):
  273. # return np.allclose(a, b, rtol=0.05, atol=0.05)
  274. # self.assertTrue(allclose(A.ravel(), patterns[0].ravel()))