fd_import.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. #!/usr/bin/env python
  2. import os, sys, csv, argparse, string
  3. import numpy as np
  4. from itertools import ifilter, imap
  5. from numpy.lib.format import open_memmap
  6. try:
  7. import ROOT
  8. # Map ROOT leaf types to numpy types
  9. typemap = {
  10. ROOT.TLeafB : np.int8, # an 8 bit signed integer (Char_t)
  11. ROOT.TLeafS : np.int16, # a 16 bit signed integer (Short_t)
  12. ROOT.TLeafI : np.int32, # a 32 bit signed integer (Int_t)
  13. ROOT.TLeafF : np.float32, # a 32 bit floating point (Float_t)
  14. ROOT.TLeafD : np.float64, # a 64 bit floating point (Double_t)
  15. ROOT.TLeafL : np.int64, # a 64 bit signed integer (Long64_t)
  16. ROOT.TLeafO : np.bool, # [the letter o, not a zero] a boolean (Bool_t)
  17. }
  18. '''
  19. ROOT.TLeafC : np.string, # a character string terminated by the 0 character
  20. ROOT.TLeafb : np.uint8, # an 8 bit unsigned integer (UChar_t)
  21. ROOT.TLeafs : np.uint16, # a 16 bit unsigned integer (UShort_t)
  22. ROOT.TLeafi : np.uint32, # a 32 bit unsigned integer (UInt_t)
  23. ROOT.TLeafl : np.uint64, # a 64 bit unsigned integer (ULong64_t)
  24. '''
  25. hasROOT = True
  26. except:
  27. print "WARNING: Couldn't import ROOT; any input NTuples will be ignored."
  28. hasROOT = False
  29. #
  30. def mkCache(base, name, shape, **kwargs):
  31. opath = os.path.join(base, "Data", name)
  32. try:
  33. os.makedirs(opath)
  34. except OSError:
  35. pass
  36. return { k : open_memmap(os.path.join(opath, k + ".npy"), dtype=t, mode='w+', shape=shape) for k, t in kwargs.items() }
  37. ### For CSV import.
  38. def getReader(csvfile):
  39. csvFilt = ifilter(lambda x: x[0] != ';', csvfile)
  40. csvFilt = imap(string.strip, csvFilt)
  41. header = [h.strip() for h in csvFilt.next().split(',') if h.strip() != '']
  42. reader = csv.DictReader(csvFilt, fieldnames=header)
  43. return header, reader
  44. def importCSV(base, name, fname, treeName):
  45. with open(fname, 'rb') as csvfile:
  46. head, read = getReader(csvfile)
  47. types = { k : np.double for k in head }
  48. nrow = 0
  49. n = 0
  50. # Count lines and create mmap'ed arrays.
  51. for _ in read:
  52. if nrow % 10000 == 0:
  53. print "\r Counting... %d" % nrow,
  54. sys.stdout.flush()
  55. nrow += 1
  56. Sample = mkCache(base, name, (nrow,), **types)
  57. # Reset to the first row and parse entries.
  58. csvfile.seek(0); csvfile.next()
  59. for row in read:
  60. for k in head:
  61. Sample[k][n] = float(row[k])
  62. n += 1
  63. if n % 10000 == 0 or n == nrow:
  64. print "\r Reading... % 8d / % 8d" % (n, nrow),
  65. sys.stdout.flush()
  66. print
  67. ### For ROOT import.
  68. def importROOT(base, name, fname, treeName):
  69. f = ROOT.TFile.Open(fname)
  70. t = getattr(f, treeName)
  71. bnames = [ b.GetName() for b in t.GetListOfBranches() ]
  72. types = { b.GetName() : typemap[type(b.GetLeaf(b.GetName()))] for b in t.GetListOfBranches() }
  73. nEvt = t.GetEntries()
  74. n = 0
  75. Sample = mkCache(base, name, (nEvt,), **types)
  76. for n, event in enumerate(t):
  77. for bname in bnames:
  78. Sample[bname][n] = getattr(event, bname)
  79. n += 1
  80. if n % 10000 == 0 or n == nEvt:
  81. print "\r Reading... % 8d / % 8d" % (n, nEvt),
  82. sys.stdout.flush()
  83. f.Close()
  84. ###### Map file types to importers.
  85. mapping = {
  86. 'csv' : importCSV,
  87. 'root' : importROOT if hasROOT else None,
  88. 'root.1' : importROOT if hasROOT else None,
  89. }
  90. ###### Okay, start it up.
  91. ArgP = argparse.ArgumentParser(description=' === Functional Decomposition Importer ===')
  92. ArgP.add_argument('--base', type=str, default=".", help="FD base directory.")
  93. ArgP.add_argument('--tree', type=str, help="Name of tree to import (ROOT files only).")
  94. ArgP.add_argument('files', default=[], nargs='*', help="List of files to import.")
  95. ArgC = ArgP.parse_args()
  96. ipath = os.path.join(ArgC.base, "Input")
  97. # Make list of input files
  98. if len(ArgC.files) > 0:
  99. fpath = ArgC.files
  100. files = [ os.path.basename(x) for x in fpath ]
  101. else:
  102. files = os.listdir(ipath)
  103. fpath = [ os.path.join(ipath, x) for x in l ]
  104. # And read them in
  105. for fname in fpath:
  106. name, ext = os.path.basename(fname).split(os.extsep, 1)
  107. print name, ext
  108. try:
  109. func = mapping[ext]
  110. except KeyError:
  111. print " WARNING: Skipping file with unrecognized extension."
  112. continue
  113. if func is not None:
  114. func(ArgC.base, name, fname, ArgC.tree)
  115. else:
  116. print " WARNING: Skipping disabled filetype."