import torch from torchvision.models import resnet18, ResNet18_Weights import torch.nn as nn class ModelCT(nn.Module): def __init__(self): super(ModelCT, self).__init__() self.backbone = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1) self.backbone.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) self.convolution2d = nn.Conv2d(512, 1, kernel_size=(1, 1), stride=(1, 1), bias=True) self.fc_maxpool = nn.AdaptiveMaxPool2d((1, 1)) def forward(self, x): x = self.backbone.conv1(x) x = self.backbone.bn1(x) x = self.backbone.relu(x) x = self.backbone.maxpool(x) x = self.backbone.layer1(x) x = self.backbone.layer2(x) x = self.backbone.layer3(x) x = self.backbone.layer4(x) x = self.convolution2d(x) x = self.fc_maxpool(x) x = torch.flatten(x, 1) return x