mnist.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
"""Example applications for image classification.

Each public function returns a pretrained MNIST model together with a copy
of it that ends in a softmax activation.
The models are based on https://doi.org/10.1371/journal.pone.0130140
and http://jmlr.org/papers/v17/15-618.html.
"""
# TODO: rename in, sm_out, out to input_tensors, output_tensors,
# TODO: softmax_output_tensors
  8. # Get Python six functionality:
  9. from __future__ import\
  10. absolute_import, print_function, division, unicode_literals
  11. ###############################################################################
  12. ###############################################################################
  13. ###############################################################################
import os
import subprocess

import keras.utils.data_utils
import numpy as np
import keras.models
from keras.models import load_model, clone_model
#from keras.utils import get_file
  20. __all__ = [
  21. "pretrained_plos_long_relu",
  22. "pretrained_plos_short_relu",
  23. "pretrained_plos_long_tanh",
  24. "pretrained_plos_short_tanh",
  25. ]
  26. ###############################################################################
  27. ###############################################################################
  28. ###############################################################################
  29. # pre-trained models from [https://doi.org/10.1371/journal.pone.0130140 , http://jmlr.org/papers/v17/15-618.html]
  30. PRETRAINED_MODELS = {"pretrained_plos_long_relu":
  31. {"file":"plos-mnist-rect-long.h5",
  32. "url" : "https://www.dropbox.com/s/26w7i58qqcuosn4/plos-mnist-rect-long.h5"
  33. },
  34. "pretrained_plos_short_relu":
  35. {"file":"plos-mnist-rect-short.h5",
  36. "url":"https://www.dropbox.com/s/89nvwyls55xycmw/plos-mnist-rect-short.h5"
  37. },
  38. "pretrained_plos_long_tanh":
  39. {"file":"plos-mnist-tanh-long.h5",
  40. "url":"https://www.dropbox.com/s/61e3a4gdbjo9bca/plos-mnist-tanh-long.h5"
  41. },
  42. "pretrained_plos_short_tanh":
  43. {"file":"plos-mnist-tanh-short.h5",
  44. "url":"https://www.dropbox.com/s/foqv60kot0retfr/plos-mnist-tanh-short.h5"
  45. },
  46. }
  47. def _load_pretrained_net(modelname, new_input_shape):
  48. filename = PRETRAINED_MODELS[modelname]["file"]
  49. urlname = PRETRAINED_MODELS[modelname]["url"]
  50. #model_path = get_file(fname=filename, origin=urlname) #TODO: FIX! corrupts the file?
  51. model_path = os.path.expanduser('~') + "/.keras/models/" + filename
  52. #workaround the more elegant, but dysfunctional solution.
  53. if not os.path.isfile(model_path):
  54. model_dir = os.path.dirname(model_path)
  55. if not os.path.isdir(model_dir):
  56. os.makedirs(model_dir)
  57. os.system("wget {} && mv -v {} {}".format(urlname, filename, model_path))
  58. model = load_model(model_path)
  59. #create replacement input layer with new shape.
  60. model.layers[0] = keras.layers.InputLayer(input_shape=new_input_shape, name="input_1")
  61. for l in model.layers:
  62. l.name = "%s_workaround" % l.name
  63. model = keras.models.Sequential(layers=model.layers)
  64. model_w_sm = clone_model(model)
  65. #NOTE: perform forward pass to fix a keras 2.2.0 related issue with improper weight initialization
  66. #See: https://github.com/albermax/innvestigate/issues/88
  67. x_dummy = np.zeros(new_input_shape)[None, ...]
  68. model_w_sm.predict(x_dummy)
  69. model_w_sm.set_weights(model.get_weights())
  70. model_w_sm.add(keras.layers.Activation("softmax"))
  71. return model, model_w_sm
  72. def pretrained_plos_long_relu(input_shape, **kwargs):
  73. return _load_pretrained_net("pretrained_plos_long_relu", input_shape)
  74. def pretrained_plos_short_relu(input_shape, **kwargs):
  75. return _load_pretrained_net("pretrained_plos_short_relu", input_shape)
  76. def pretrained_plos_long_tanh(input_shape, **kwargs):
  77. return _load_pretrained_net("pretrained_plos_long_tanh", input_shape)
  78. def pretrained_plos_short_tanh(input_shape, **kwargs):
  79. return _load_pretrained_net("pretrained_plos_short_tanh", input_shape)