model.py 924 B

1234567891011121314151617181920212223242526
  1. import torch
  2. import torchvision.models as tmodels
  3. import torch.nn as nn
  4. class ModelCT(nn.Module):
  5. def __init__(self):
  6. super(ModelCT, self).__init__()
  7. self.backbone = tmodels.resnet18(pretrained=True)
  8. self.backbone.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  9. self.convolution2d = nn.Conv2d(512, 1, kernel_size=(1, 1), stride=(1, 1), bias=True)
  10. self.fc_maxpool = nn.AdaptiveMaxPool2d((1, 1))
  11. def forward(self, x):
  12. x = self.backbone.conv1(x)
  13. x = self.backbone.bn1(x)
  14. x = self.backbone.relu(x)
  15. x = self.backbone.maxpool(x)
  16. x = self.backbone.layer1(x)
  17. x = self.backbone.layer2(x)
  18. x = self.backbone.layer3(x)
  19. x = self.backbone.layer4(x)
  20. x = self.convolution2d(x)
  21. x = self.fc_maxpool(x)
  22. x = torch.flatten(x, 1)
  23. return x