"""Example applications for image classifcation. Each function returns a pretrained MNIST model. The models are based on https://doi.org/10.1371/journal.pone.0130140 and http://jmlr.org/papers/v17/15-618.html. """ # TODO: rename in, sm_out, out to input_tensors, output_tensors, # TODO: softmax_output_tenors # Get Python six functionality: from __future__ import\ absolute_import, print_function, division, unicode_literals ############################################################################### ############################################################################### ############################################################################### import os import keras.utils.data_utils import numpy as np import keras.models from keras.models import load_model, clone_model #from keras.utils import get_file __all__ = [ "pretrained_plos_long_relu", "pretrained_plos_short_relu", "pretrained_plos_long_tanh", "pretrained_plos_short_tanh", ] ############################################################################### ############################################################################### ############################################################################### # pre-trained models from [https://doi.org/10.1371/journal.pone.0130140 , http://jmlr.org/papers/v17/15-618.html] PRETRAINED_MODELS = {"pretrained_plos_long_relu": {"file":"plos-mnist-rect-long.h5", "url" : "https://www.dropbox.com/s/26w7i58qqcuosn4/plos-mnist-rect-long.h5" }, "pretrained_plos_short_relu": {"file":"plos-mnist-rect-short.h5", "url":"https://www.dropbox.com/s/89nvwyls55xycmw/plos-mnist-rect-short.h5" }, "pretrained_plos_long_tanh": {"file":"plos-mnist-tanh-long.h5", "url":"https://www.dropbox.com/s/61e3a4gdbjo9bca/plos-mnist-tanh-long.h5" }, "pretrained_plos_short_tanh": {"file":"plos-mnist-tanh-short.h5", "url":"https://www.dropbox.com/s/foqv60kot0retfr/plos-mnist-tanh-short.h5" }, } def _load_pretrained_net(modelname, new_input_shape): filename = PRETRAINED_MODELS[modelname]["file"] urlname = PRETRAINED_MODELS[modelname]["url"] #model_path = get_file(fname=filename, origin=urlname) #TODO: FIX! corrupts the file? model_path = os.path.expanduser('~') + "/.keras/models/" + filename #workaround the more elegant, but dysfunctional solution. if not os.path.isfile(model_path): model_dir = os.path.dirname(model_path) if not os.path.isdir(model_dir): os.makedirs(model_dir) os.system("wget {} && mv -v {} {}".format(urlname, filename, model_path)) model = load_model(model_path) #create replacement input layer with new shape. model.layers[0] = keras.layers.InputLayer(input_shape=new_input_shape, name="input_1") for l in model.layers: l.name = "%s_workaround" % l.name model = keras.models.Sequential(layers=model.layers) model_w_sm = clone_model(model) #NOTE: perform forward pass to fix a keras 2.2.0 related issue with improper weight initialization #See: https://github.com/albermax/innvestigate/issues/88 x_dummy = np.zeros(new_input_shape)[None, ...] model_w_sm.predict(x_dummy) model_w_sm.set_weights(model.get_weights()) model_w_sm.add(keras.layers.Activation("softmax")) return model, model_w_sm def pretrained_plos_long_relu(input_shape, **kwargs): return _load_pretrained_net("pretrained_plos_long_relu", input_shape) def pretrained_plos_short_relu(input_shape, **kwargs): return _load_pretrained_net("pretrained_plos_short_relu", input_shape) def pretrained_plos_long_tanh(input_shape, **kwargs): return _load_pretrained_net("pretrained_plos_long_tanh", input_shape) def pretrained_plos_short_tanh(input_shape, **kwargs): return _load_pretrained_net("pretrained_plos_short_tanh", input_shape)