Made the individual components of the "flags" and "flags2" columns accessible

New function MPRfile.get_flag(flagname)
This commit is contained in:
Chris Kerr
2014-05-01 23:02:23 +01:00
parent c302ec9117
commit cfffeee2e2
2 changed files with 38 additions and 16 deletions

View File

@@ -131,11 +131,29 @@ VMPmodule_hdr = np.dtype([('shortname', 'S10'),
def VMPdata_dtype_from_colIDs(colIDs):
dtype_dict = OrderedDict()
flags_dict = OrderedDict()
flags2_dict = OrderedDict()
for colID in colIDs:
if colID in (1, 2, 3, 21, 31, 65):
dtype_dict['flags'] = 'u1'
if colID == 1:
flags_dict['mode'] = (np.uint8(0x03), np.uint8)
elif colID == 2:
flags_dict['ox/red'] = (np.uint8(0x04), np.bool_)
elif colID == 3:
flags_dict['error'] = (np.uint8(0x08), np.bool_)
elif colID == 21:
flags_dict['control changes'] = (np.uint8(0x10), np.bool_)
elif colID == 31:
flags_dict['Ns changes'] = (np.uint8(0x20), np.bool_)
elif colID == 65:
flags_dict['counter inc.'] = (np.uint8(0x80), np.bool_)
else:
raise NotImplementedError("flag %d not implemented" % colID)
elif colID in (131,):
dtype_dict['flags2'] = '<u2'
if colID == 131:
flags2_dict['??'] = (np.uint16(0x0001), np.bool_)
elif colID == 4:
dtype_dict['time/s'] = '<f8'
elif colID == 5:
@@ -159,7 +177,7 @@ def VMPdata_dtype_from_colIDs(colIDs):
dtype_dict['(Q-Qo)/C'] = '<f4'
else:
raise NotImplementedError("column type %d not implemented" % colID)
return np.dtype(list(dtype_dict.items()))
return np.dtype(list(dtype_dict.items())), flags_dict, flags2_dict
def read_VMP_modules(fileobj, read_module_data=True):
@@ -253,7 +271,7 @@ class MPRfile:
else:
assert(not any(remaining_headers))
self.dtype = VMPdata_dtype_from_colIDs(column_types)
self.dtype, self.flags_dict, self.flags2_dict = VMPdata_dtype_from_colIDs(column_types)
self.data = np.fromstring(main_data, dtype=self.dtype)
assert(self.data.shape[0] == n_data_points)
@@ -291,3 +309,13 @@ class MPRfile:
Start date: %s
End date: %s
Timestamp: %s""" % (self.startdate, self.enddate, self.timestamp))
def get_flag(self, flagname):
if flagname in self.flags_dict:
mask, dtype = self.flags_dict[flagname]
return np.array(self.data['flags'] & mask, dtype=dtype)
elif flagname in self.flags2_dict:
mask, dtype = self.flags2_dict[flagname]
return np.array(self.data['flags2'] & mask, dtype=dtype)
else:
raise AttributeError("Flag '%s' not present" % flagname)

View File

@@ -111,19 +111,14 @@ def assert_MPR_matches_MPT(mpr, mpt, comments):
if fieldname in mpr.dtype.fields:
assert_array_equal(mpr.data[fieldname], mpt[fieldname])
assert_array_equal(mpr.data["flags"] & 0x03, mpt["mode"])
assert_array_equal(np.array(mpr.data["flags"] & 0x04, dtype=np.bool_),
mpt["ox/red"])
assert_array_equal(np.array(mpr.data["flags"] & 0x08, dtype=np.bool_),
mpt["error"])
assert_array_equal(np.array(mpr.data["flags"] & 0x10, dtype=np.bool_),
mpt["control changes"])
if "Ns changes" in mpt.dtype.fields:
assert_array_equal(np.array(mpr.data["flags"] & 0x20, dtype=np.bool_),
mpt["Ns changes"])
## Nothing uses the 0x40 bit of the flags
assert_array_equal(np.array(mpr.data["flags"] & 0x80, dtype=np.bool_),
mpt["counter inc."])
assert_array_equal(mpr.get_flag("mode"), mpt["mode"])
assert_array_equal(mpr.get_flag("ox/red"), mpt["ox/red"])
assert_array_equal(mpr.get_flag("error"), mpt["error"])
assert_array_equal(mpr.get_flag("control changes"), mpt["control changes"])
if "Ns changes" in mpt.dtype.fields:
assert_array_equal(mpr.get_flag("Ns changes"), mpt["Ns changes"])
## Nothing uses the 0x40 bit of the flags
assert_array_equal(mpr.get_flag("counter inc."), mpt["counter inc."])
assert_array_almost_equal(mpr.data["time/s"],
mpt["time/s"],
@@ -147,7 +142,6 @@ def assert_MPR_matches_MPT(mpr, mpt, comments):
eq_(timestamp_from_comments(comments), mpr.timestamp)
except AttributeError:
pass
def test_MPR1_matches_MPT1():