augmentation.py 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. import numpy as np
  2. from keras.preprocessing.image import Iterator
  3. class CustomIterator(Iterator):
  4. def __init__(self, data, batch_size=6, shuffle=False, seed=None,
  5. dim_ordering='tf'):
  6. self.mri_data, self.jac_data, self.xls_data, self.labels, self.ptid, self.imageID, self.confid, self.csf = data
  7. self.dim_ordering = dim_ordering
  8. self.batch_size = batch_size
  9. super(CustomIterator, self).__init__(self.mri_data.shape[0], batch_size, shuffle, seed)
  10. def _get_batches_of_transformed_samples(self, index_array):
  11. batch_mri = np.zeros(tuple([len(index_array)] + list(self.mri_data.shape[1:])))
  12. batch_jac = np.zeros(tuple([len(index_array)] + list(self.jac_data.shape[1:])))
  13. batch_xls = np.zeros(tuple([len(index_array)] + list(self.xls_data.shape[1:])))
  14. batch_labels = np.zeros(tuple([len(index_array)]),dtype=object)
  15. for i, j in enumerate(index_array):
  16. mri = self.mri_data[j]
  17. jac = self.jac_data[j]
  18. xls = self.xls_data[j]
  19. batch_mri[i]= mri
  20. batch_jac[i]= jac
  21. batch_xls[i]= xls
  22. batch_labels[i] = self.labels[j]
  23. # print('batch label'+str(index_array)+":")
  24. # print(batch_labels)
  25. # print('batch mri'+str(index_array)+":")
  26. # print(batch_mri)
  27. return [batch_mri, batch_jac, batch_xls], batch_labels
  28. def next(self):
  29. with self.lock:
  30. index_array = next(self.index_generator)
  31. return self._get_batches_of_transformed_samples(index_array)