A set of tools to verify the operation of the j2ms2 and tConvert programs

#!/usr/bin/env python3
#######################/opt/anaconda3/bin/python3
# compare-ms-idi.py
#
# In order to gain confidence that everything from
# a MeasurementSet ended up in the FITS-IDI file, this
# tool accumulates statistics (currently: exposure time)
# per baseline per source.
# Deviations between source and destination numbers
# might indicate problems.
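#
# Typical invocation (the option names come from the argparse section at the
# bottom of this file; the file names below are placeholders, not real data):
#
#   compare-ms-idi.py --ms /path/to/experiment.ms --idi /path/to/experiment.IDI1 /path/to/experiment.IDI2
#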
from __future__ import print_function
from six import with_metaclass # Grrrrr http://python-future.org/compatible_idioms.html#metaclasses
from functools import partial, reduce
import sys, re, collections, glob, os, operator, itertools, astropy.io.fits, numpy, argparse
import pyrap.tables as pt
# everybody should love themselves some function composition. really.
compose  = lambda *fns : lambda x: reduce(lambda acc, f: f(acc), reversed(fns), x)
pam      = lambda *fns : lambda x: tuple(map(lambda fn: fn(x), fns)) # fns= [f0, f1, ...] => map(x, fns) => (f0(x), f1(x), ...)
Apply    = lambda *args : args[0](*args[1:]) # args= (f0, arg0, arg1,...) => f0(arg0, arg1, ...)
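# Reading aid for the two combinators above (illustrative only, not used below):
#   compose(f, g, h)(x)  evaluates as f(g(h(x)))
#   pam(f, g)(x)         evaluates as (f(x), g(x))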
identity = lambda x : x
choice   = lambda p, t, f : lambda x: t(x) if p(x) else f(x)
const    = lambda c : lambda *_: c
method   = lambda f : lambda *args, **kwargs: f(*args, **kwargs)
Map      = lambda fn : partial(map, fn)
#Group   = lambda n : operator.methodcaller('group', n)
GroupBy  = lambda keyfn : partial(itertools.groupby, key=keyfn)
Sort     = lambda keyfn : partial(sorted, key=keyfn)
Filter   = lambda pred : partial(filter, pred)
Reduce   = lambda *args : partial(reduce, *args)
ZipStar  = lambda x : zip(*x)
Star     = lambda f : lambda args: f(*args)
StarMap  = lambda f : partial(itertools.starmap, f)
D        = lambda x : print(x) or x
DD       = lambda pfx : lambda x: print(pfx, ":", x) or x
Type     = lambda **kwargs : type('', (), kwargs)
Derived  = lambda b, **kw : type('', (b,), kw) # Create easy derived type so attributes can be set/added
Obj      = lambda **kwargs : Type(**kwargs)()
Append   = lambda l, v : l.append(v) or l
SetAttr  = lambda o, a, v : setattr(o, a, v) or o
XFormAttr = lambda attr, f, missing=None: lambda obj, *args, **kwargs: SetAttr(obj, attr, f(getattr(obj, attr, missing), *args, **kwargs))
# Thanks to Py3 one must sometimes drain an iterable for its side effects. Thanks guys!
# From https://docs.python.org/2/library/itertools.html#recipes
# consume(), all_equal()
consume   = partial(collections.deque, maxlen=0)
all_equal = compose(lambda g: next(g, True) and not next(g, False), itertools.groupby)
# shorthands
GetN   = operator.itemgetter
GetA   = operator.attrgetter
Repeat = itertools.repeat
Call   = operator.methodcaller
# other generic stuff (own invention)
n_times = choice(partial(operator.ne, 1), "{0:6d} times".format, const("once")) # polite integer formatting w/ special case for 1
##################################################################################
#
# This is the information we accumulate per primary key (whatever that may
# turn out to be)
#
##################################################################################
SumAttr = lambda a: lambda o0, o1: setattr(o0, a, getattr(o0, a)+getattr(o1, a))
FloatEq = lambda l, r: abs(l-r)<1e-3
Statistics = Type(weight=0, exposure=0, n=0, n_neg=0, source=None,
                  __repr__ = method(compose(" ".join, Filter(operator.truth),
                                            pam("{0.exposure:9.2f}s wgt={0.weight:9.2f}".format,
                                                compose(choice(partial(operator.ne, 0), "{0:4d}<0".format, const("")), GetA('n_neg')),
                                                compose(n_times, GetA('n'))))),
                  # float members compare differently than integers, d'uh; if either of l, r has n_neg!=0 they compare unequal
                  # (so even if l.n_neg == r.n_neg and n_neg > 0 they still compare unequal, in order to make them
                  #  be reported as having an issue ...)
                  __eq__ = lambda l, r: FloatEq(l.weight, r.weight) and FloatEq(l.exposure, r.exposure) and l.n==r.n and (l.n_neg+r.n_neg)==0,
                  __ne__ = lambda l, r: not (r==l),
                  # object += other => __iadd__(self, other) => __iadd__(*a)
                  # from each element in args get the attribute(s) of interest and sum them into args[0] (== 'object' == self)
                  __iadd__ = lambda *a: compose(const(None), pam(*map(compose(Star, SumAttr), ['weight', 'exposure', 'n', 'n_neg'])))(a) or a[0])
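# Reading aid, restating the definitions above: "stat += other" sums the weight,
# exposure, n and n_neg members into "stat", and two Statistics objects compare
# equal only when weight and exposure agree to within 1e-3, n matches exactly
# and neither side has counted any negative weights.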
DefaultDict = Derived(collections.defaultdict)
Source = lambda **kwargs: Obj(format=kwargs.get('format', 'Missing format='), location=kwargs.get('location', 'Missing location='),
                              __repr__=method(compose("{0.format:>4s}: {0.location}".format, XFormAttr('format', str.upper))))
SetSource = XFormAttr('source', lambda _, tp: tp) # just overwrite the attribute's value

class WrapLookup(object):
    def __init__(self, realLUT, unknown="UnknownKey[{0}]".format):
        (self.LUT, self.unknown) = (realLUT, unknown)
    def __getitem__(self, key):
        return self.LUT.get(key, self.unknown(key))
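# Illustrative only: WrapLookup({0: 'Xx'})[0] returns 'Xx', while a missing key
# falls back to the formatter, e.g. WrapLookup({0: 'Xx'})[1] returns 'UnknownKey[1]'.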
########################################################################################
#
# Reduce a set of columns, with a snag - in order to not have to read whole
# columns of a MeasurementSet - the reduction is done in chunks.
#
# The default chunker+slicer are geared towards grinding over a MeasurementSet
# but should be flexible enough to also work - with some user level override -
# on any number of "columnal" data [e.g. lists, or, as we'll see, FITS binary tables :-)]
#
########################################################################################
Chunk  = collections.namedtuple('Chunk' , ['first', 'nrow'])
Column = collections.namedtuple('Column', ['name' , 'slicer'])

# yield Chunk() objects to cover the range [first, last) in steps of chunksize
def chunkert(f, l, cs):
    # note: the generator simply falls off the end when the range is exhausted;
    # raising StopIteration inside a generator would be a RuntimeError under
    # PEP 479 (Python 3.7+)
    while f<l:
        n = min(cs, l-f)
        yield Chunk(first=f, nrow=n)
        f = f + n
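# Illustrative only: list(chunkert(0, 5, 2)) yields
#   [Chunk(first=0, nrow=2), Chunk(first=2, nrow=2), Chunk(first=4, nrow=1)]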
## Bare-bones reduce over a MS. No fancy progress display
# columns  = [<colname0>, <colname1>, ...]
# function = f(accumulator, arg0, arg1, ..., arg<len(columns)>)
def reduce_ms(function, ms, init, colnames, **kwargs):
    # allow user to override for specific column (pass slicer_fn/4 in via the 'slicers' kwarg)
    # slicer_fn/4 is called as (<table>, <column>, <start>, <nrow>)
    slicers = kwargs.get('slicers', {})
    # Transform the list of column names into a list of Column() objects with .name, .slicer attributes
    columns = compose(list, Map(lambda col: Column(name=col, slicer=slicers.get(col, lambda t, c, s, n: t.getcol(c, startrow=s, nrow=n)))))(colnames)
    # A function that, given a particular Chunk (.first, .nrow attributes), returns those rows for all columns
    getcols = lambda chunk: Map(lambda col: col.slicer(ms, col.name, chunk.first, chunk.nrow))(columns)
    # Step over the whole table, reducing it one Chunk of rows at a time
    return Reduce(lambda acc, chunk: function(acc, *getcols(chunk)))(chunkert(0, len(ms), kwargs.get('chunksize', 16384)), init)
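# Reading aid (hypothetical column names, just to show the calling convention):
#   reduce_ms(f, ms, init, ['COL_A', 'COL_B']) is, in effect, equivalent to
#       acc = init
#       for chunk in chunkert(0, len(ms), 16384):
#           acc = f(acc, <COL_A rows of chunk>, <COL_B rows of chunk>)
#       return acc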
# A helper generator which drains a dict and yields the items
def drain_dict(d):
    while d:
        yield d.popitem()
#############################################################################################################################
#
# MeasurementSet specific handling
#
#############################################################################################################################
# For directly indexed MS subtables with a NAME column (eg MS::ANTENNA, MS::FIELD) form lookup table "index => Name"
mk_ms_lut = lambda ms: compose(WrapLookup, dict, enumerate, Map(str.capitalize), Call('getcol', 'NAME'), partial(pt.table, ack=False), ms.getkeyword)
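# Illustrative only (station names made up): mk_ms_lut(ms)('ANTENNA') opens the
# ANTENNA subtable and returns a WrapLookup along the lines of {0: 'Ef', 1: 'Wb', ...}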
def process_ms(path_to_ms):
    # Slightly modified collections.Counter - implement a ".inc()" method which returns the new value
    # so it's easy to check for dups (".inc(key) > 1")
    class MCounter(collections.Counter):
        def inc(self, key):
            self[key] += 1
            return self[key]
    timestamps = collections.defaultdict(lambda: collections.defaultdict(MCounter))
    # The accumulator function, taking the columns of the MS as parameters
    def accumulate(acc, a1, a2, fld, exp, weight, time, flag_row, dd_id):
        # weights < 0 will need to be masked, as well as those of rows that are flagged;
        # weight has shape (n_row, n_pol), flag_row has shape (n_row,)
        # (plain Python bool is used as the dtype; the numpy.bool alias is gone in recent numpy releases)
        n_neg  = numpy.sum( weight<0, axis=1 )
        weight = numpy.sum( weight, axis=1 ) * numpy.array(~flag_row.view(bool), dtype=bool)
        n_zero = numpy.sum( weight<1e-4 )
        for i in range(len(a1)):
            key, ts, dd = ((a1[i], a2[i], fld[i]), time[i], dd_id[i])
            o = acc[ key ]
            # MS have each subband ("IF" or "BAND" in AIPS speak) in a separate
            # row; in FITS all IFs are stored in one row of the UV_DATA table.
            # Consequently, just adding EXPOSURE for each row in the MS and
            # INTTIM for each row in the UV_DATA aren't going to match up :-(
            # Need something which returns EXPOSURE once for each key we're
            # counting and 0 on every other occurrence.
            # Well, there could be *duplicate* rows in the MS! If the same correlator
            # data was added multiple times (accidentally). So we add a check for *that*
            # as well. Count up the exposure time + number of occurrences if there is no
            # data for the time stamp yet OR if the exact same key
            # (baseline, field, time, data_desc) was previously found
            kt_set = timestamps[key][ts]
            is_duplicate = kt_set.inc(dd)>1
            (exposure, n) = (exp[i], 1) if len(kt_set)==1 or is_duplicate else (0, 0)
            o.n        += n
            o.exposure += exposure
            o.n_neg    += n_neg[i]
            o.weight   += weight[i]
        return acc
    with pt.table(path_to_ms, ack=True) as ms:
        # Collect all statistics
        stats = reduce_ms(accumulate, ms, collections.defaultdict(Statistics),
                          ['ANTENNA1', 'ANTENNA2', 'FIELD_ID', 'EXPOSURE', 'WEIGHT', 'TIME', 'FLAG_ROW', 'DATA_DESC_ID'])
        # read antenna, field_id mappings
        antab, srctab = map(mk_ms_lut(ms), ['ANTENNA', 'FIELD'])
        # and translate (a1, a2, field_id) numbers into ('XantYant', 'FieldName')
        unmap_ = lambda k: (antab[k[0]]+antab[k[1]], srctab[k[2]])
        return SetSource(Reduce(lambda acc, kv: acc.__setitem__(unmap_(kv[0]), kv[1]) or acc)(drain_dict(stats), Derived(dict)()),
                         Source(format='ms', location=path_to_ms))
#############################################################################################################################
#
# FITS-IDI specific handling
#
#############################################################################################################################
# Given "TABLE, [INDEX_COL, NAME_COL]" produce a WrapLookup object to unmap INDEX to NAME in standardized form (capitalized)
# xform_columns: create column-getting functions, adding special processing to the NAME_COL column
xform_columns = lambda cols: map(Apply, [GetN, lambda c: compose(Map(compose(str.capitalize, str)), GetN(c))], cols)
# mk_idi_lut is a function returning a function that will create a LUT from an IDI based on "TABLE, [INDEX_COL, NAME_COL]"
mk_idi_lut = lambda tbl, cols, unk: compose(partial(WrapLookup, unknown=unk), dict, ZipStar, pam(*xform_columns(cols)), GetA('data'), GetN(tbl))
# input: list of (absolute) path names, output: string of the file names grouped by containing directory
summarize = compose("; ".join, Map("{0[0]}: {0[1]}*".format), Map(pam(GetN(0), compose(os.path.commonprefix, list, Map(GetN(1)), GetN(1)))),
                    GroupBy(GetN(0)), Sort(GetN(0)), Map(pam(os.path.dirname, os.path.basename)))
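# Illustrative only (made-up paths): summarize(['/data/exp_1.IDI1', '/data/exp_1.IDI2'])
# produces the string "/data: exp_1.IDI*"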
# The accumulation function that accumulates the statistics of the UV_DATA table of one IDI file
def analyze_uvdata_table(baseline, source_id, inttim, flux):
    acc = collections.defaultdict(Statistics)
    # shape of FLUX is
    #    (nRow, dec, ra, n_if, n_chan, n_pol, nComplex)
    #      0     1   2     3      4      5       6
    assert flux.shape[-1] == 3, "This FITS-IDI file does not have a weight-per-spectral-point"
    # Assert that weight does not vary across channel.
    # Extract the weights and transpose to move the channel axis to the beginning:
    #    (n_chan, nRow, dec, ra, n_if, n_pol)
    #       0      1     2   3     4     5
    # such that the "==" broadcasts nicely over all channels
    weight = flux[Ellipsis, 2].transpose((4, 0, 1, 2, 3, 5))
    assert numpy.all(weight == weight[0, Ellipsis]), "Weight varies across channels?!"
    # Slice channel 0 and ra==0 and dec==0
    weight = weight[0, :, 0, 0, Ellipsis]
    # count number of negatives per row (== per baseline, source)
    n_neg = numpy.sum( weight<0, axis=(2,1) )
    # Now sum the weights over IFs and polarizations
    weight = numpy.sum(weight, axis=(2,1))
    for i in range(len(baseline)):
        o = acc[ (baseline[i], source_id[i]) ]
        o.exposure += inttim[i]
        o.weight   += weight[i]
        o.n_neg    += n_neg[i]
        o.n        += 1
    return acc
# Helpers for dealing with FITS-IDI files
# - opening an IDI file in a specific mode; the added Debug output prints the
#   actual file name, which is nice feedback for the user :-)
# - building antenna, source lookup tables
# - given an opened FITS-IDI object, returning the accumulated statistics from the UV_DATA table
open_idi      = compose(partial(astropy.io.fits.open, memmap=True, mode='readonly'), D)
mk_idi_antab  = mk_idi_lut('ANTENNA', ['ANTENNA_NO', 'ANNAME'], "Antenna#{0}".format)
mk_idi_srctab = mk_idi_lut('SOURCE' , ['SOURCE_ID' , 'SOURCE' ], "Source#{0}".format)
get_idi_stats = compose(Star(analyze_uvdata_table), pam(*list(map(GetN, ['BASELINE', 'SOURCE_ID', 'INTTIM', 'FLUX']))), GetA('data'), GetN('UV_DATA'))
# Actually process a series of FITS-IDI file(s)
def process_idi(list_of_idis):
    # reduction of a single FITS-IDI file
    def process_one_idi(acc, path_to_idi):
        idi   = open_idi(path_to_idi)
        stats = get_idi_stats(idi)
        # construct antenna, source unmappers
        antab, srctab = pam(mk_idi_antab, mk_idi_srctab)(idi)
        # aggregate items we found in this IDI into the global accumulator;
        # in the process we unmap (baseline, source_id) to ('Anname1Anname2', 'Sourcename')
        def aggregate(a, item):
            # the FITS-IDI BASELINE column packs the pair as 256*antenna1 + antenna2,
            # so divmod() recovers the two antenna numbers
            a1, a2 = divmod(item[0][0], 256)
            a[ (antab[a1]+antab[a2], srctab[item[0][1]]) ] += item[1]
            return a
        return reduce(aggregate, drain_dict(stats), acc)
    # reduce all IDI files and set the source attribute on the returned object to tell whence these came
    return SetSource(reduce(process_one_idi, list_of_idis, DefaultDict(Statistics)), Source(format='idi', location=summarize(list_of_idis)))
# take a list of dicts with dict(key => statistics) and compare them
def report(list_of_stats):
    # transform [Stats, Stats, ...] into [(set(keys), Stats), (set(keys), Stats), ...]
    # Yes. I know. The keys of a dict ARE unique by definition so the "set()" doesn't seem to
    # add anything. However, when finding out common or extra or missing keys across multiple data sets,
    # the set() operations like intersection and difference are exactly what we need.
    # So creating them as set() objects once makes them reusable.
    list_of_stats = compose(list, Map(lambda s: (set(s.keys()), s)))(list_of_stats)
    # can only report on common keys
    common_keys = reduce(set.intersection, map(GetN(0), list_of_stats))
    # warn about non-matching keys
    def count_extra_keys(acc, item):  # item = (set(keys), Statistics())
        print("="*4, "Problem report", "="*4+"\n", item[1].source, "\n", "Extra keys:\n",
              "\n".join(map(compose("\t{0[0]} found {0[1]}".format, pam(identity, lambda k: n_times(item[1][k].n))), item[0])), "\n"+"="*25)
        return acc.append( len(item[0]) ) or acc
    nExtra = reduce(count_extra_keys,
                    filter(compose(operator.truth, GetN(0)),
                           Map(pam(compose(common_keys.__rsub__, GetN(0)), GetN(1)))(list_of_stats)),
                    list())
    # For each common key check that all collected values are the same, by counting the number of keys
    # that don't fulfil this predicate. Also print all values + the source whence they came in case of such a mismatch
    get_key = lambda k: compose(pam(GetN(k), GetA('source')), GetN(1))
    def check_key(acc, key):
        values, sources = ZipStar(map(get_key(key), list_of_stats))
        return acc if all_equal(values) else print(key, ":\n", "\n".join(map("\t{0} in {1}".format, values, sources))) or acc+1
    nProb = reduce(check_key, sorted(common_keys), 0)
    # And one line of feedback about what we found in total
    print("Checked", len(list_of_stats), "data sets,", len(common_keys), "common keys",
          choice(partial(operator.eq, 0), const(""), "with {0} problems identified".format)(nProb),
          choice(operator.truth, compose("and {0[1]} non-common keys in {0[0]} formats".format, pam(len, sum)), const(""))(nExtra))
    return nProb
# This is the main function, so to speak: execute all statistics-gathering functions
# and, if there is more than one result, report the differences/errors
main = compose(sys.exit, choice(compose(partial(operator.lt, 1), len), report, compose(const(0), print)), list, Map(Apply))
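# Reading aid, restating the line above: main(fns) calls each gathered processing
# function; if that yields more than one data set it runs report() on them and
# exits with the number of problems found, otherwise it prints the single (or
# empty) result and exits with 0.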
############################################################################
#
# Let argparse do some work - such that we don't have to repeat it
#
############################################################################
# This metaclass extracts the 'process_fn' from a Process()'s __init__'s
# keyword arguments and turns it into an argparse.Action __call__() function
# which appends a lambda() to the Action's destination, which, when called,
# executes the processing function
class ProcessMeta(type):
    def __call__(cls, *args, **kwargs):
        # Note: pop the 'process_fn=' keyword from kwargs such that
        # when Process.__init__() calls argparse.Action.__init__(), the
        # latter won't barf; it doesn't like unrecognized keyword arguments ...
        process_fn = kwargs.pop('process_fn', None)
        if process_fn is None:
            raise RuntimeError("Missing process_fn= keyword argument in __init__() for ", kwargs.get('help', "Unknown Process action"))
        # Need to do two things: add the __call__() special method to the class
        # which looks up the actual method-to-call in the instance
        # See: https://stackoverflow.com/a/33824320/26083
        setattr(cls, '__call__', lambda *args, **kwargs: args[0].__call__(*args, **kwargs))
        # and now decorate the instance with a call method consistent with
        # https://docs.python.org/3/library/argparse.html#action-classes
        #   "Action instances should be callable, so subclasses must override
        #    the __call__ method, which should accept four parameters: ..."
        return SetAttr(type.__call__(cls, *args, **kwargs), '__call__',
                       lambda self, _parser, namespace, values, _opt: XFormAttr(self.dest, Append)(namespace, lambda: process_fn(values)))

# Create a custom argparse.Action to add statistics gathering functions
Process = Derived(with_metaclass(ProcessMeta, argparse.Action))
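# Net effect, restated: an option such as "--ms X" does not process anything at
# parse time; it merely appends "lambda: process_ms('X')" to namespace.path, and
# the deferred calls are only executed when main() maps Apply over that list.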
if __name__ == '__main__':
    ###################################################################################################################
    #
    # Here's where the only interesting stuff happens: parsing the command line & executing stuff!
    #
    ###################################################################################################################
    parsert = argparse.ArgumentParser(description="Compare contents of one MeasurementSet and one or more FITS-IDI files")
    parsert.add_argument('--ms',  action=Process, dest='path', default=list(), required=False, help="Specify the source MeasurementSet", process_fn=process_ms)
    parsert.add_argument('--idi', action=Process, dest='path', default=list(), nargs='*', help="The FITS-IDI file(s) produced from 'ms='", process_fn=process_idi)
    parsert.add_argument('--lis', action=Process, dest='path', default=list(), required=False, help="The .lis file that was used to construct 'ms='",
                         process_fn=compose(lambda lis: SetSource(DefaultDict(Statistics), Source(format='lis', location=lis)), DD("lis-file processing not yet implemented")))
    main(parsert.parse_args().path)