system.py 296 B

12345678910
  1. import torch
  2. # Forces torch to initialize cuDNN
  3. # From StackOverflow https://stackoverflow.com/questions/66588715
  4. def force_init_cudnn(dev=torch.device("cuda:0")):
  5. s = 32
  6. torch.nn.functional.conv2d(
  7. torch.zeros(s, s, s, s, device=dev), torch.zeros(s, s, s, s, device=dev)
  8. )