A set of tools to verify the operation of the j2ms2 and tConvert programs

#!/usr/bin/env python3
#######################/opt/anaconda3/bin/python3
# compare-ms-idi.py
#
# in order to gain confidence that everything from
# a MeasurementSet ended up the FITS-IDI file this
# tool accumulates statistics (currently: exposure time)
# per baseline per source.
# Deviations between source and destination numbers
# might indicate problems.
from __future__ import print_function
from six import with_metaclass # Grrrrr http://python-future.org/compatible_idioms.html#metaclasses
from functools import partial, reduce
import sys, re, collections, glob, os, operator, itertools, astropy.io.fits, numpy, argparse
import pyrap.tables as pt
# everybody should love themselves some function composition. really.
compose = lambda *fns : lambda x: reduce(lambda acc, f: f(acc), reversed(fns), x)
pam = lambda *fns : lambda x: tuple(map(lambda fn: fn(x), fns)) # fns= [f0, f1, ...] => map(x, fns) => (f0(x), f1(x), ...)
Apply = lambda *args : args[0](*args[1:]) # args= (f0, arg0, arg1,...) => f0(arg0, arg1, ...)
identity = lambda x : x
choice = lambda p, t, f : lambda x: t(x) if p(x) else f(x)
const = lambda c : lambda *_: c
method = lambda f : lambda *args, **kwargs: f(*args, **kwargs)
Map = lambda fn : partial(map, fn)
#Group = lambda n : operator.methodcaller('group', n)
GroupBy = lambda keyfn : partial(itertools.groupby, key=keyfn)
Sort = lambda keyfn : partial(sorted, key=keyfn)
Filter = lambda pred : partial(filter, pred)
Reduce = lambda *args : partial(reduce, *args)
ZipStar = lambda x : zip(*x)
Star = lambda f : lambda args: f(*args)
StarMap = lambda f : partial(itertools.starmap, f)
D = lambda x : print(x) or x
DD = lambda pfx : lambda x: print(pfx,":",x) or x
Type = lambda **kwargs : type('', (), kwargs)
Derived = lambda b, **kw : type('', (b,), kw) # Create easy derived type so attributes can be set/added
Obj = lambda **kwargs : Type(**kwargs)()
Append = lambda l, v : l.append(v) or l
SetAttr = lambda o, a, v : setattr(o, a, v) or o
XFormAttr= lambda attr, f, missing=None: lambda obj, *args, **kwargs: SetAttr(obj, attr, f(getattr(obj, attr, missing), *args, **kwargs))
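# A few illustrative examples of the combinators above (values are made up):
#   compose(len, str)(1234)                     -> 4            # == len(str(1234))
#   pam(min, max)([3, 1, 2])                    -> (1, 3)
#   Apply(operator.add, 1, 2)                   -> 3
#   choice(bool, const('yes'), const('no'))(1)  -> 'yes'
#   Append([1, 2], 3)                           -> [1, 2, 3]    # append and return the list
#   XFormAttr('n', operator.add, 0)(obj, 5)     # obj.n = getattr(obj, 'n', 0) + 5, returns obj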
# Thanks to Py3 one must sometimes drain an iterable for its side effects. Thanks guys!
# From https://docs.python.org/2/library/itertools.html#recipes
# consume(), all_equal()
consume = partial(collections.deque, maxlen=0)
all_equal= compose(lambda g: next(g, True) and not next(g, False), itertools.groupby)
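# e.g. all_equal([1, 1, 1]) -> True, all_equal([1, 2]) -> False
#      consume(map(print, [1, 2]))   # drains the lazy map purely for the print side effect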
# shorthands
GetN = operator.itemgetter
GetA = operator.attrgetter
Repeat = itertools.repeat
Call = operator.methodcaller
# other generic stuff (own invention)
n_times = choice(partial(operator.ne, 1), "{0:6d} times".format, const("once")) # polite integer formatting w/ special case for 1
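# e.g. n_times(1) -> "once", n_times(3) -> "     3 times"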
##################################################################################
#
# this is the information we accumulate per primary key (whatever that may turn
# out to be)
#
##################################################################################
SumAttr = lambda a: lambda o0, o1: setattr(o0, a, getattr(o0, a)+getattr(o1, a))
FloatEq = lambda l, r: abs(l-r)<1e-3
Statistics = Type(weight=0, exposure=0, n=0, n_neg=0, source=None,
__repr__ = method(compose(" ".join, Filter(operator.truth),
pam("{0.exposure:9.2f}s wgt={0.weight:9.2f}".format,
compose(choice(partial(operator.ne, 0), "{0:4d}<0".format, const("")), GetA('n_neg')),
compose(n_times, GetA('n'))))),
# float members compare with a tolerance, unlike the integer ones. If either l or r has n_neg != 0 the
# objects compare unequal (so even if l.n_neg == r.n_neg and n_neg > 0 they still compare unequal, to
# make sure they get reported as having an issue ...)
__eq__ = lambda l,r: FloatEq(l.weight, r.weight) and FloatEq(l.exposure, r.exposure) and l.n==r.n and (l.n_neg+r.n_neg)==0,
__ne__ = lambda l,r: not (r==l),
# object += other => __iadd__(self, other) => __iadd__(*a)
# from each element in args get the attribute(s) of interest and sum them into args[0] (== 'object' == self)
__iadd__ = lambda *a: compose(const(None), pam(*map(compose(Star, SumAttr), ['weight', 'exposure', 'n', 'n_neg'])))(a) or a[0])
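# Illustrative use of Statistics (the numbers are made up):
#   s = Statistics(); s.exposure, s.weight, s.n = 2.0, 16.0, 1
#   s += other          # merges weight, exposure, n and n_neg from 'other' into s
#   s == other          # True only if weight and exposure agree to within 1e-3 (FloatEq),
#                       # n is identical and neither side saw negative weights (n_neg == 0)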
DefaultDict = Derived(collections.defaultdict)
Source = lambda **kwargs: Obj(format=kwargs.get('format', 'Missing format='), location=kwargs.get('location', 'Missing location='),
__repr__=method(compose("{0.format:>4s}: {0.location}".format, XFormAttr('format', str.upper))))
SetSource = XFormAttr('source', lambda _, tp: tp) # just overwrite the attribute's value
class WrapLookup(object):
    def __init__(self, realLUT, unknown="UnknownKey[{0}]".format):
        (self.LUT, self.unknown) = (realLUT, unknown)
    def __getitem__(self, key):
        return self.LUT.get(key, self.unknown(key))
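# e.g. WrapLookup({0: 'Ef', 1: 'Wb'})[1] -> 'Wb' and WrapLookup({0: 'Ef'})[42] -> 'UnknownKey[42]'
#      ('Ef', 'Wb' are hypothetical antenna names)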
########################################################################################
#
# Reduce a set of columns, with a snag - in order to not have to read whole
# columns of a MeasurementSet - the reduction is done in chunks.
#
# The default chunker+slicer are geared towards grinding over a MeasurementSet
# but should be flexible enough to also work - with some user level override -
# on any number of "columnal" data [e.g. lists, or, as we'll see, FITS binary tables :-)]
#
########################################################################################
Chunk = collections.namedtuple('Chunk' , ['first' , 'nrow'])
Column = collections.namedtuple('Column' , ['name' , 'slicer'])
# yield Chunk() objects to cover the range [first, last) in steps of chunksize
def chunkert(f, l, cs):
    while f < l:
        n = min(cs, l-f)
        yield Chunk(first=f, nrow=n)
        f = f + n
    # note: no explicit "raise StopIteration" here - since PEP 479 (Python 3.7) raising
    # StopIteration inside a generator surfaces as RuntimeError; falling off the end suffices
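# e.g. list(chunkert(0, 5, 2)) -> [Chunk(first=0, nrow=2), Chunk(first=2, nrow=2), Chunk(first=4, nrow=1)]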
## Bare-bones reduce ms. No fancy progress display
#   columns  = [<colname0>, <colname1>, ...]
#   function = f(accumulator, arg0, arg1, ..., arg<len(columns)>)
def reduce_ms(function, ms, init, colnames, **kwargs):
    # allow user to override for specific column (pass slicer_fn/4 in via 'slicers' kwarg)
    # slicer_fn/4 is called as (<table>, <column>, <start>, <nrow>)
    slicers = kwargs.get('slicers', {})
    # Transform the list of column names into a list of Column() objects with .name, .slicer attributes
    columns = compose(list, Map(lambda col: Column(name=col, slicer=slicers.get(col, lambda t, c, s, n: t.getcol(c, startrow=s, nrow=n)))))(colnames)
    # A function that, given a particular Chunk (.first, .nrow attributes), returns those rows for all columns
    getcols = lambda chunk: Map(lambda col: col.slicer(ms, col.name, chunk.first, chunk.nrow))(columns)
    # Step over the whole table, reducing it one Chunk of rows at a time
    return Reduce(lambda acc, chunk: function(acc, *getcols(chunk)))(chunkert(0, len(ms), kwargs.get('chunksize', 16384)), init)
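# Illustrative (hypothetical) use: sum the EXPOSURE column of a MeasurementSet one chunk at a time:
#   total_exposure = reduce_ms(lambda acc, exposure: acc + numpy.sum(exposure), ms, 0.0, ['EXPOSURE'])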
# A helper generator which drains a dict and yields the items
def drain_dict(d):
    while d:
        yield d.popitem()
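# e.g. list(drain_dict({'a': 1})) -> [('a', 1)]  (and the dict is left empty afterwards)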
#############################################################################################################################
#
# MeasurementSet specific handling
#
#############################################################################################################################
# For directly indexed MS subtables with a NAME column (eg MS::ANTENNA, MS::FIELD) form lookup table "index => Name"
mk_ms_lut = lambda ms: compose(WrapLookup, dict, enumerate, Map(str.capitalize), Call('getcol', 'NAME'), partial(pt.table, ack=False), ms.getkeyword)
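# e.g. mk_ms_lut(ms)('ANTENNA') yields a WrapLookup mapping antenna index -> capitalized NAME,
#      something like {0: 'Ef', 1: 'Wb', ...} (hypothetical station names)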
def process_ms(path_to_ms):
    # Slightly modified collections.Counter - implement a ".inc()" method which returns the new value
    # so it's easy to check for dups (".inc(key) > 1")
    class MCounter(collections.Counter):
        def inc(self, key):
            self[key] += 1
            return self[key]
    timestamps = collections.defaultdict(lambda: collections.defaultdict(MCounter))
    # The accumulator function, taking the columns of the MS as parameters
    def accumulate(acc, a1, a2, fld, exp, weight, time, flag_row, dd_id):
        # weights < 0 will need to be masked, as well as those of rows that are flagged.
        # weight has shape (n_row, n_pol); after summing over the polarisation axis it has
        # shape (n_row,), which broadcasts directly against flag_row (also shape (n_row,))
        n_neg  = numpy.sum( weight<0, axis=1 )
        # note: use the builtin bool - the numpy.bool alias was removed in recent NumPy releases
        weight = numpy.sum( weight, axis=1 ) * numpy.array(~flag_row.astype(bool), dtype=bool)
        n_zero = numpy.sum( weight<1e-4 )
        for i in range(len(a1)):
            key, ts, dd = ((a1[i], a2[i], fld[i]), time[i], dd_id[i])
            o = acc[ key ]
            # MS have each subband ("IF" or "BAND" in AIPS speak) in a separate
            # row, in FITS all IFs are stored in one row of the UV_DATA table.
            # Consequently, just adding EXPOSURE for each row in the MS and
            # INTTIM for each row in the UV_DATA isn't going to match up :-(
            # Need something which returns EXPOSURE once for each key we're
            # counting and 0 on every other occurrence.
            # Well, there could be *duplicate* rows in the MS! If the same correlator
            # data was added multiple times (accidentally). So we add a check for *that*
            # as well. Count up the exposure time + number of occurrences if there is no
            # data for the time stamp yet OR if the exact same key
            # (baseline, field, time, data_desc) was previously found
            kt_set       = timestamps[key][ts]
            is_duplicate = kt_set.inc(dd)>1
            (exposure, n) = (exp[i], 1) if len(kt_set)==1 or is_duplicate else (0, 0)
            o.n        += n
            o.exposure += exposure
            o.n_neg    += n_neg[i]
            o.weight   += weight[i]
        return acc
    with pt.table(path_to_ms, ack=True) as ms:
        # Collect all statistics
        stats = reduce_ms(accumulate, ms, collections.defaultdict(Statistics),
                          ['ANTENNA1', 'ANTENNA2', 'FIELD_ID', 'EXPOSURE', 'WEIGHT', 'TIME', 'FLAG_ROW', 'DATA_DESC_ID'])
        # read antenna, field_id mappings
        antab, srctab = map(mk_ms_lut(ms), ['ANTENNA', 'FIELD'])
        # and translate (a1, a2, field_id) numbers into ('XantYant', 'FieldName')
        unmap_ = lambda k: (antab[k[0]]+antab[k[1]], srctab[k[2]])
        return SetSource(Reduce(lambda acc, kv: acc.__setitem__(unmap_(kv[0]), kv[1]) or acc)(drain_dict(stats), Derived(dict)()),
                         Source(format='ms', location=path_to_ms))
#############################################################################################################################
#
# FITS-IDI specific handling
#
#############################################################################################################################
# Given "TABLE, [INDEX_COL, NAME_COL]" produce a WrapLookup object to unmap INDEX to NAME in standardized form (capitalized)
# xform_columns: create column getting functions, adding special processing to the NAME_COL column
xform_columns = lambda cols: map(Apply, [GetN, lambda c: compose(Map(compose(str.capitalize, str)), GetN(c))], cols)
# mk_idi_lut is a function, returning a function that will create a LUT from an IDI based on "TABLE, [INDEX_COL, NAME_COL]"
mk_idi_lut = lambda tbl, cols, unk: compose(partial(WrapLookup, unknown=unk), dict, ZipStar, pam(*xform_columns(cols)), GetA('data'), GetN(tbl))
# input: list of (absolute) path names, output: a string with the file names grouped by containing directory
summarize = compose("; ".join, Map("{0[0]}: {0[1]}*".format), Map(pam(GetN(0), compose(os.path.commonprefix, list, Map(GetN(1)), GetN(1)))),
GroupBy(GetN(0)), Sort(GetN(0)), Map(pam(os.path.dirname, os.path.basename)))
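# e.g. summarize(['/data/exp_1_1.IDI1', '/data/exp_1_1.IDI2']) -> "/data: exp_1_1.IDI*"
#      (hypothetical file names)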
# The accumulation function that accumulates the statistics of the UV_DATA table of one IDI file
def analyze_uvdata_table(baseline, source_id, inttim, flux):
    acc = collections.defaultdict(Statistics)
    # shape of FLUX is
    #   (nRow, dec, ra, n_if, n_chan, n_pol, nComplex)
    #     0     1   2    3      4       5       6
    assert flux.shape[-1] == 3, "This FITS-IDI file does not have a weight-per-spectral-point"
    # Assert that weight does not vary across channel.
    # Extract the weights and transpose to move the channel axis to the beginning:
    #   (n_chan, nRow, dec, ra, n_if, n_pol)
    #      0      1    2   3    4      5
    # such that the "==" broadcasts nicely over all channels
    weight = flux[Ellipsis,2].transpose((4, 0, 1, 2, 3, 5))
    assert numpy.all(weight == weight[0, Ellipsis]), "Weight varies across channels?!"
    # Keep only channel 0, ra==0 and dec==0; remaining shape is (nRow, n_if, n_pol)
    weight = weight[0, :, 0, 0, Ellipsis]
    # count number of negatives per row (== per baseline, source)
    n_neg  = numpy.sum( weight<0, axis=(2,1) )
    # Now sum the weights over IFs and polarizations
    weight = numpy.sum(weight, axis=(2,1))
    for i in range(len(baseline)):
        o = acc[ (baseline[i], source_id[i]) ]
        o.exposure += inttim[i]
        o.weight   += weight[i]
        o.n_neg    += n_neg[i]
        o.n        += 1
    return acc
# Helpers for dealing with FITS-IDI files
# - opening an IDI file in read-only mode; the D() debug wrapper prints the
#   actual file name, which is nice feedback for the user :-)
# - building antenna, source lookup tables
# - Given an opened FITS-IDI object, returning the accumulated statistics from the UV_DATA table
open_idi = compose(partial(astropy.io.fits.open, memmap=True, mode='readonly'), D)
mk_idi_antab = mk_idi_lut('ANTENNA', ['ANTENNA_NO', 'ANNAME'], "Antenna#{0}".format)
mk_idi_srctab = mk_idi_lut('SOURCE' , ['SOURCE_ID' , 'SOURCE'], "Source#{0}".format)
get_idi_stats = compose(Star(analyze_uvdata_table), pam(*list(map(GetN, ['BASELINE', 'SOURCE_ID', 'INTTIM', 'FLUX']))), GetA('data'), GetN('UV_DATA'))
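# e.g. mk_idi_antab(idi)[1] -> 'Ef' (hypothetical), i.e. ANTENNA_NO -> capitalized ANNAME;
#      unknown antenna numbers come back as "Antenna#<no>" rather than raising a KeyError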
# Actually process a series of FITS-IDI file(s)
def process_idi(list_of_idis):
    # reduction of a single FITS-IDI file
    def process_one_idi(acc, path_to_idi):
        idi   = open_idi(path_to_idi)
        stats = get_idi_stats(idi)
        # construct antenna, source unmappers
        antab, srctab = pam(mk_idi_antab, mk_idi_srctab)(idi)
        # aggregate items we found in this IDI into the global accumulator
        # in the process we unmap (baseline, source_id) to ('Anname1Anname2', 'Sourcename')
        def aggregate(a, item):
            a1, a2 = divmod(item[0][0], 256)
            a[ (antab[a1]+antab[a2], srctab[item[0][1]]) ] += item[1]
            return a
        return reduce(aggregate, drain_dict(stats), acc)
    # reduce all IDI files and set the source attribute on the returned object to tell whence these came
    return SetSource(reduce(process_one_idi, list_of_idis, DefaultDict(Statistics)),
                     Source(format='idi', location=summarize(list_of_idis)))
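# Note on the decoding in aggregate() above: FITS-IDI stores BASELINE as 256*antenna1 + antenna2
# (for antenna numbers below 256), so divmod(baseline, 256) recovers the two ANTENNA_NO values.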
# take a list of dicts with dict(key=>statistics) and compare them
def report(list_of_stats):
    # transform [Stats, Stats, ...] into [(set(keys), Stats), (set(keys), Stats), ...]
    # Yes. I know. The keys of a dict ARE unique by definition so the "set()" doesn't seem to
    # add anything. However, when finding out common or extra or missing keys across multiple data sets
    # the set() operations like intersection and difference are exactly what we need.
    # So creating them as set() objects once makes them reusable.
    list_of_stats = compose(list, Map(lambda s: (set(s.keys()), s)))(list_of_stats)
    # can only report on common keys
    common_keys = reduce(set.intersection, map(GetN(0), list_of_stats))
    # warn about non-matching keys
    def count_extra_keys(acc, item): # item = (set(keys), Statistics())
        print("="*4, "Problem report", "="*4+"\n", item[1].source, "\n", "Extra keys:\n",
              "\n".join(map(compose("\t{0[0]} found {0[1]}".format, pam(identity, lambda k: n_times(item[1][k].n))), item[0])), "\n"+"="*25)
        return acc.append( len(item[0]) ) or acc
    nExtra = reduce(count_extra_keys,
                    filter(compose(operator.truth, GetN(0)),
                           Map(pam(compose(common_keys.__rsub__, GetN(0)), GetN(1)))(list_of_stats)),
                    list())
    # For each common key check that all collected values are the same by counting the number of keys
    # that don't fulfil this predicate. Also print all values + the source whence they came in case of such a mismatch
    get_key = lambda k: compose(pam(GetN(k), GetA('source')), GetN(1))
    def check_key(acc, key):
        values, sources = ZipStar(map(get_key(key), list_of_stats))
        return acc if all_equal(values) else print(key, ":\n", "\n".join(map("\t{0} in {1}".format, values, sources))) or acc+1
    nProb = reduce(check_key, sorted(common_keys), 0)
    # And one line of feedback about what we found in total
    print("Checked", len(list_of_stats), "data sets,", len(common_keys), "common keys",
          choice(partial(operator.eq, 0), const(""), "with {0} problems identified".format)(nProb),
          choice(operator.truth, compose("and {0[1]} non-common keys in {0[0]} formats".format, pam(len, sum)), const(""))(nExtra))
    return nProb
# This is the main driver: execute all statistics-gathering functions and,
# if there is more than one result, report the differences/errors
main = compose(sys.exit, choice(compose(partial(operator.lt, 1), len), report, compose(const(0), print)), list, Map(Apply))
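# e.g. main([lambda: process_ms('x.ms'), lambda: process_idi(['x.IDI1'])])  (hypothetical paths)
# evaluates the deferred gatherers, passes the resulting statistics to report() when there is more
# than one data set, and uses the number of problems found as the process exit code.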
############################################################################
#
# Let argparse do some work - such that we don't have to repeat it
#
############################################################################
# This metaclass extracts the 'process_fn' from a Process()'s __init__'s
# keyword arguments and turns it into an argparse.Action __call__() function
# which appends a lambda() to the Action's destination, which, when called,
# executes the processing function
class ProcessMeta(type):
    def __call__(cls, *args, **kwargs):
        # Note: pop the 'process_fn=' keyword from kwargs such that
        # when Process.__init__() calls argparse.Action.__init__(), the
        # latter won't barf; it doesn't like unrecognized keyword arguments ...
        process_fn = kwargs.pop('process_fn', None)
        if process_fn is None:
            raise RuntimeError("Missing process_fn= keyword argument in __init__() for "+kwargs.get('help', "Unknown Process action"))
        # Need to do two things: add the __call__() special method to the class
        # which looks up the actual method-to-call in the instance
        # See: https://stackoverflow.com/a/33824320/26083
        setattr(cls, '__call__', lambda *args, **kwargs: args[0].__call__(*args, **kwargs))
        # and now decorate the instance with a call method consistent with
        # https://docs.python.org/3/library/argparse.html#action-classes
        #    "Action instances should be callable, so subclasses must override
        #     the __call__ method, which should accept four parameters: ..."
        return SetAttr(type.__call__(cls, *args, **kwargs), '__call__',
                       lambda self, _parser, namespace, values, _opt: XFormAttr(self.dest, Append)(namespace, lambda: process_fn(values)))
# Create a custom argparse.Action to add statistics gathering functions
Process = Derived(with_metaclass(ProcessMeta, argparse.Action))
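# e.g. add_argument('--ms', action=Process, dest='path', default=list(), process_fn=process_ms)
# appends "lambda: process_ms(<value>)" to namespace.path each time the option is given;
# the actual processing is deferred until main() evaluates the collected list.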
if __name__ == '__main__':
    ###################################################################################################################
    #
    # Here's where the only interesting stuff happens: parsing the command line & executing stuff!
    #
    ###################################################################################################################
    parsert = argparse.ArgumentParser(description="Compare contents of one MeasurementSet and one or more FITS-IDI files")
    parsert.add_argument('--ms', action=Process, dest='path', default=list(), required=False, help="Specify the source MeasurementSet", process_fn=process_ms)
    parsert.add_argument('--idi', action=Process, dest='path', default=list(), nargs='*', help="The FITS-IDI file(s) produced from 'ms='", process_fn=process_idi)
    parsert.add_argument('--lis', action=Process, dest='path', default=list(), required=False, help="The .lis file that was used to construct 'ms='",
                         process_fn=compose(lambda lis: SetSource(DefaultDict(Statistics), Source(format='lis', location=lis)), DD("lis-file processing not yet implemented")))
    main(parsert.parse_args().path)