CNN_methods.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. from torch import add
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. import torch.optim as optim
  5. """
  6. Returns a function that applies a convolution (or a separable convolution), batch normalization, an ELU activation, optional max pooling, and dropout to its input.
  7. Kernel_size = (height, width, depth)
  8. Pass sep_conv=True to use a separable convolution.
  9. """
  10. def conv_elu_maxpool_drop(in_channel, filters, kernel_size, stride=(1,1,1), padding=0, dilation=1,
  11. groups=1, bias=True, padding_mode='zeros', pool=False, drop_rate=0, sep_conv = False):
  12. def f(input):
  13. # SEPARABLE CONVOLUTION
  14. if(sep_conv):
  15. # SepConv depthwise, Normalizes, and ELU activates
  16. sepConvDepthwise = nn.Conv3d(in_channel, filters, kernel_size, stride=stride, padding=padding,
  17. groups=in_channel, bias=bias, padding_mode=padding_mode)(input)
  18. # SepConv pointwise
  19. # Todo, will stride & padding be correct for this?
  20. conv = nn.Conv3d(in_channel, filters, kernel_size=1, stride=stride, padding=padding,
  21. groups=1, bias=bias, padding_mode=padding_mode)(sepConvDepthwise)
  22. # CONVOLUTES
  23. else:
  24. # Convolutes, Normalizes, and ELU activates
  25. conv = nn.Conv3d(in_channel, filters, kernel_size, stride=stride, padding=padding, dilation=dilation,
  26. groups=groups, bias=bias, padding_mode=padding_mode)(input)
  27. normalization = nn.BatchNorm3d(filters)(conv)
  28. elu = nn.ELU()(normalization)
  29. # Pools
  30. if (pool):
  31. elu = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(elu)
  32. return nn.Dropout(p=drop_rate)(elu)
  33. return f
  34. '''
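  35. Middle flow of the CNN: convolves the input 3 times, adds the residual (the initial input) to the result, and applies an ELU activation. The description calls for separable convolutions, but mid_flow currently uses the plain convolution path (sep_conv is left at its default).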
  36. '''
  37. def mid_flow(I, drop_rate, filters):
  38. in_channel = None # TODO, IN_CHANNEL
  39. residual = I # TODO, DOES THIS ACTUALLY COPY?
  40. x = conv_elu_maxpool_drop(in_channel, filters, (3,3,3), drop_rate=drop_rate)(I)
  41. x = conv_elu_maxpool_drop(in_channel, filters, (3,3,3), drop_rate=drop_rate)(x)
  42. x = conv_elu_maxpool_drop(in_channel, filters, (3, 3, 3), drop_rate=drop_rate)(x)
  43. x = add(x, residual)
  44. x = nn.ELU()(x)
  45. return x
  46. """
  47. Returns a function that applies a fully connected (FC) layer, batch normalization, an ELU activation, and dropout to its input.
  48. """
  49. def fc_elu_drop(in_features, units, drop_rate=0):
  50. def f(input):
  51. fc = nn.Linear(in_features, out_features=units)(input)
  52. fc = nn.BatchNorm3d(units)(fc) # TODO 3d or 2d???
  53. fc = nn.ELU()(fc)
  54. fc = nn.Dropout(p=drop_rate)
  55. return fc
  56. return f