소스 검색

pretrained got changed to weights in torch 1.0+

Zan Klanecek 1 년 전
부모
커밋
a5eb890ba4
1개의 변경된 파일2개의 추가작업 그리고 2개의 파일을 삭제
  1. 2 2
      5-naloga-cnn-klasifikacija-covid-slik/model.py

+ 2 - 2
5-naloga-cnn-klasifikacija-covid-slik/model.py

@@ -1,11 +1,11 @@
 import torch
-import torchvision.models as tmodels
+from torchvision.models import resnet18, ResNet18_Weights
 import torch.nn as nn
 
 class ModelCT(nn.Module):
     def __init__(self):
         super(ModelCT, self).__init__()
-        self.backbone = tmodels.resnet18(pretrained=True)
+        self.backbone = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
         self.backbone.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
         self.convolution2d = nn.Conv2d(512, 1, kernel_size=(1, 1), stride=(1, 1), bias=True)
         self.fc_maxpool = nn.AdaptiveMaxPool2d((1, 1))