datareader.py 602 B

123456789101112131415161718192021
  1. import numpy as np
  2. import torch.utils.data as tdata
  3. import os
  4. class DataReader(tdata.Dataset):
  5. def __init__(self, main_path_to_data, data_info):
  6. super(DataReader, self).__init__()
  7. self.data = data_info
  8. self.num_sample = len(self.data)
  9. self.main_path_to_data = main_path_to_data
  10. def __len__(self):
  11. return self.num_sample
  12. def __getitem__(self, n):
  13. filename, label = self.data[n]
  14. path_to_file = os.path.join(self.main_path_to_data, filename)
  15. img = np.load(path_to_file)
  16. return img, np.float32([label])