__init__.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. # Get Python six functionality:
  2. from __future__ import\
  3. absolute_import, print_function, division, unicode_literals
  4. ###############################################################################
  5. ###############################################################################
  6. ###############################################################################
  7. from .base import NotAnalyzeableModelException
  8. from .deeplift import DeepLIFT
  9. from .deeplift import DeepLIFTWrapper
  10. from .gradient_based import BaselineGradient
  11. from .gradient_based import Gradient
  12. from .gradient_based import InputTimesGradient
  13. from .gradient_based import GuidedBackprop
  14. from .gradient_based import Deconvnet
  15. from .gradient_based import IntegratedGradients
  16. from .gradient_based import SmoothGrad
  17. from .misc import Input
  18. from .misc import Random
  19. from .pattern_based import PatternNet
  20. from .pattern_based import PatternAttribution
  21. from .relevance_based.relevance_analyzer import BaselineLRPZ
  22. from .relevance_based.relevance_analyzer import LRP
  23. from .relevance_based.relevance_analyzer import LRPZ
  24. from .relevance_based.relevance_analyzer import LRPZIgnoreBias
  25. from .relevance_based.relevance_analyzer import LRPZPlus
  26. from .relevance_based.relevance_analyzer import LRPZPlusFast
  27. from .relevance_based.relevance_analyzer import LRPEpsilon
  28. from .relevance_based.relevance_analyzer import LRPEpsilonIgnoreBias
  29. from .relevance_based.relevance_analyzer import LRPWSquare
  30. from .relevance_based.relevance_analyzer import LRPFlat
  31. from .relevance_based.relevance_analyzer import LRPAlphaBeta
  32. from .relevance_based.relevance_analyzer import LRPAlpha2Beta1
  33. from .relevance_based.relevance_analyzer import LRPAlpha2Beta1IgnoreBias
  34. from .relevance_based.relevance_analyzer import LRPAlpha1Beta0
  35. from .relevance_based.relevance_analyzer import LRPAlpha1Beta0IgnoreBias
  36. from .relevance_based.relevance_analyzer import LRPSequentialPresetA
  37. from .relevance_based.relevance_analyzer import LRPSequentialPresetB
  38. from .relevance_based.relevance_analyzer import LRPSequentialPresetAFlat
  39. from .relevance_based.relevance_analyzer import LRPSequentialPresetBFlat
  40. from .relevance_based.relevance_analyzer import LRPSequentialPresetBFlatUntilIdx
  41. from .deeptaylor import DeepTaylor
  42. from .deeptaylor import BoundedDeepTaylor
  43. from .wrapper import WrapperBase
  44. from .wrapper import AugmentReduceBase
  45. from .wrapper import GaussianSmoother
  46. from .wrapper import PathIntegrator
  47. # Disable pyflaks warnings:
  48. assert NotAnalyzeableModelException
  49. assert DeepLIFT
  50. assert BaselineLRPZ
  51. assert WrapperBase
  52. assert AugmentReduceBase
  53. assert GaussianSmoother
  54. assert PathIntegrator
  55. ###############################################################################
  56. ###############################################################################
  57. ###############################################################################
  58. analyzers = {
  59. # Utility.
  60. "input": Input,
  61. "random": Random,
  62. # Gradient based
  63. "gradient": Gradient,
  64. "gradient.baseline": BaselineGradient,
  65. "input_t_gradient": InputTimesGradient,
  66. "deconvnet": Deconvnet,
  67. "guided_backprop": GuidedBackprop,
  68. "integrated_gradients": IntegratedGradients,
  69. "smoothgrad": SmoothGrad,
  70. # Relevance based
  71. "lrp": LRP,
  72. "lrp.z": LRPZ,
  73. "lrp.z_IB": LRPZIgnoreBias,
  74. "lrp.epsilon": LRPEpsilon,
  75. "lrp.epsilon_IB": LRPEpsilonIgnoreBias,
  76. "lrp.w_square": LRPWSquare,
  77. "lrp.flat": LRPFlat,
  78. "lrp.alpha_beta": LRPAlphaBeta,
  79. "lrp.alpha_2_beta_1": LRPAlpha2Beta1,
  80. "lrp.alpha_2_beta_1_IB": LRPAlpha2Beta1IgnoreBias,
  81. "lrp.alpha_1_beta_0": LRPAlpha1Beta0,
  82. "lrp.alpha_1_beta_0_IB": LRPAlpha1Beta0IgnoreBias,
  83. "lrp.z_plus": LRPZPlus,
  84. "lrp.z_plus_fast": LRPZPlusFast,
  85. "lrp.sequential_preset_a": LRPSequentialPresetA,
  86. "lrp.sequential_preset_b": LRPSequentialPresetB,
  87. "lrp.sequential_preset_a_flat": LRPSequentialPresetAFlat,
  88. "lrp.sequential_preset_b_flat": LRPSequentialPresetBFlat,
  89. "lrp.sequential_preset_b_flat_until_idx": LRPSequentialPresetBFlatUntilIdx,
  90. # Deep Taylor
  91. "deep_taylor": DeepTaylor,
  92. "deep_taylor.bounded": BoundedDeepTaylor,
  93. # DeepLIFT
  94. #"deep_lift": DeepLIFT,
  95. "deep_lift.wrapper": DeepLIFTWrapper,
  96. # Pattern based
  97. "pattern.net": PatternNet,
  98. "pattern.attr": PatternAttribution,
  99. }
  100. def create_analyzer(name, model, **kwargs):
  101. """Instantiates the analyzer with the name 'name'
  102. This convenience function takes an analyzer name
  103. creates the respective analyzer.
  104. Alternatively analyzers can be created directly by
  105. instantiating the respective classes.
  106. :param name: Name of the analyzer.
  107. :param model: The model to analyze, passed to the analyzer's __init__.
  108. :param kwargs: Additional parameters for the analyzer's .
  109. :return: An instance of the chosen analyzer.
  110. :raise KeyError: If there is no analyzer with the passed name.
  111. """
  112. try:
  113. analyzer_class = analyzers[name]
  114. except KeyError:
  115. raise KeyError(
  116. "No analyzer with the name '%s' could be found."
  117. " All possible names are: %s" % (name, list(analyzers.keys())))
  118. return analyzer_class(model, **kwargs)