import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class Conv_elu_maxpool_drop(nn.Module):
    """3-D convolution -> BatchNorm3d -> ELU -> optional max-pool -> dropout.

    Args:
        input_size: Number of input channels (``in_channels`` of the conv).
        output_size: Number of output channels (``out_channels`` of the conv).
        kernel_size: Convolution kernel size, forwarded to ``nn.Conv3d``.
        prps: Hyper-parameter dict. Keys read here: ``'bias'``,
            ``'padding_mode'`` and ``'drop_rate'``; ``'dilation'`` is read
            only when ``sep_conv`` is true.
        stride: Convolution stride, forwarded to ``nn.Conv3d``.
        pool: If true, apply ``MaxPool3d(kernel_size=3, stride=2, padding=0)``
            after the activation.
        sep_conv: If true, ``forward`` uses ``sepConvDepthwise`` instead of
            ``conv``.
        padding: Convolution padding (int, tuple or ``'same'``), forwarded to
            ``nn.Conv3d``.
    """

    def __init__(self, input_size, output_size, kernel_size, prps,
                 stride=(1, 1, 1), pool=False, sep_conv=False, padding=0):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.pool_status = pool
        self.sep_conv_status = sep_conv

        # LAYERS
        if self.sep_conv_status:
            # NOTE(review): despite its name this is a single grouped conv with
            # groups=2 (so input_size and output_size must both be even), not a
            # depthwise (groups=input_size) + 1x1x1 pointwise pair. Replacing
            # it would change the parameter layout, so it is left unchanged.
            self.sepConvDepthwise = nn.Conv3d(
                input_size, output_size, kernel_size=kernel_size,
                stride=stride, padding=padding, dilation=prps['dilation'],
                groups=2, bias=prps["bias"], padding_mode=prps["padding_mode"])
        # Built unconditionally; when sep_conv is true forward() never uses it.
        # Unlike sepConvDepthwise it does not receive prps['dilation'].
        self.conv = nn.Conv3d(
            input_size, output_size, kernel_size=kernel_size, stride=stride,
            padding=padding, groups=1, bias=prps["bias"],
            padding_mode=prps["padding_mode"])
        self.normalization = nn.BatchNorm3d(output_size)
        self.elu = nn.ELU()
        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=0)
        self.dropout = nn.Dropout(p=prps['drop_rate'])
        # Not used by forward(). Removing them would change the state_dict keys
        # and break strict loading of existing checkpoints, so they stay.
        self.weight = nn.Parameter(torch.randn(input_size, output_size))
        self.bias = nn.Parameter(torch.randn(output_size))

    def forward(self, x):
        """Run conv -> BatchNorm -> ELU -> optional max-pool -> dropout on ``x``.

        ``x`` is fed straight to ``nn.Conv3d``, so it is expected to be a
        (N, C, D, H, W) tensor with C == input_size.
        """
        conv = self.sepConvDepthwise if self.sep_conv_status else self.conv
        x = conv(x)
        x = self.normalization(x)
        x = self.elu(x)
        if self.pool_status:
            # The pooled tensor must be kept; it was previously discarded, so
            # pool=True had no effect.
            x = self.maxpool(x)
        x = self.dropout(x)
        return x
class Mid_flow(nn.Module):
    """Residual block: one conv unit applied three times, a skip add, then ELU.

    The conv unit is a ``Conv_elu_maxpool_drop`` with a 3x3x3 kernel, stride 1,
    ``padding='same'`` and ``sep_conv=True``. Because the same unit is applied
    repeatedly and its output is added to the input, ``input_size`` must equal
    ``output_size``.

    NOTE(review): the three applications share ONE layer, meaning one set of
    conv weights and one BatchNorm with shared running statistics.
    Xception-style middle flows usually stack three independent layers.
    Confirm this weight sharing is intended.

    Args:
        input_size: Channel count of the input.
        output_size: Channel count of the output (see the constraint above).
        prps: Hyper-parameter dict forwarded to ``Conv_elu_maxpool_drop``.
    """

    def __init__(self, input_size, output_size, prps):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size

        # LAYERS
        self.conv = Conv_elu_maxpool_drop(
            input_size, output_size, kernel_size=(3, 3, 3), stride=(1, 1, 1),
            sep_conv=True, padding='same', prps=prps)
        self.elu = nn.ELU()
        # Not used by forward(). Kept so the state_dict keys stay unchanged.
        self.weight = nn.Parameter(torch.randn(input_size, output_size))
        self.bias = nn.Parameter(torch.randn(output_size))

    def forward(self, x):
        """Return ``ELU(conv(conv(conv(x))) + x)``."""
        # No clone needed: none of the layers below modify their input in place.
        residual = x
        for _ in range(3):
            x = self.conv(x)
        x = torch.add(x, residual)
        x = self.elu(x)
        return x


class Fc_elu_drop(nn.Module):
    """Linear -> BatchNorm1d -> ELU -> dropout, optionally followed by softmax.

    Args:
        input_size: Number of input features.
        output_size: Number of output features.
        softmax: If true, apply a softmax over the feature dimension (dim 1)
            as the last step.
        prps: Hyper-parameter dict; only ``'drop_rate'`` is read.
    """

    def __init__(self, input_size, output_size, softmax, prps):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size

        # LAYERS
        self.linear = nn.Linear(input_size, output_size)
        self.normalization = nn.BatchNorm1d(output_size)
        self.elu = nn.ELU()
        self.dropout = nn.Dropout(p=prps['drop_rate'])
        self.softmax_status = softmax
        if softmax:
            # nn.Softmax() without ``dim`` is deprecated: it warns and picks the
            # dimension implicitly from the input rank. dim=1 is the feature
            # dimension BatchNorm1d normalises over, and equals the implicit
            # choice for 2-D (N, features) input.
            self.softmax = nn.Softmax(dim=1)
        # Not used by forward(). Kept so the state_dict keys stay unchanged.
        self.weight = nn.Parameter(torch.randn(input_size, output_size))
        self.bias = nn.Parameter(torch.randn(output_size))

    def forward(self, x):
        """Apply linear -> BatchNorm1d -> ELU -> dropout (-> softmax) to ``x``.

        ``x`` is presumably (N, input_size). BatchNorm1d(output_size) needs the
        features on dim 1 after the linear layer. TODO confirm with callers.
        """
        x = self.linear(x)
        x = self.normalization(x)
        x = self.elu(x)
        x = self.dropout(x)
        if self.softmax_status:
            x = self.softmax(x)
        return x