Refactor the VMPdata_dtype_from_colIDs function

This commit is contained in:
2019-05-03 19:12:18 +02:00
parent 6b0f8b6d37
commit 1f57e48602

View File

@@ -9,7 +9,7 @@ import csv
from os import SEEK_SET from os import SEEK_SET
import time import time
from datetime import date, datetime, timedelta from datetime import date, datetime, timedelta
from collections import OrderedDict from collections import defaultdict, OrderedDict
import numpy as np import numpy as np
@@ -201,29 +201,46 @@ VMPdata_colID_flag_map = {
def VMPdata_dtype_from_colIDs(colIDs): def VMPdata_dtype_from_colIDs(colIDs):
"""Get a numpy record type from a list of column ID numbers.
The binary layout of the data in the MPR file is described by the sequence
of column ID numbers in the file header. This function converts that
sequence into a numpy dtype which can then be used to load data from the
file with np.frombuffer().
Some column IDs refer to small values which are packed into a single byte.
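The second return value is a dict describing the bit masks with which to
extract these columns from the flags byte. The third return value is a
second dict (flags2_dict), which this function creates but never fills.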
"""
type_list = [] type_list = []
field_list = [] field_name_counts = defaultdict(int)
flags_dict = OrderedDict() flags_dict = OrderedDict()
flags2_dict = OrderedDict() flags2_dict = OrderedDict()
for colID in colIDs: for colID in colIDs:
if colID in VMPdata_colID_flag_map: if colID in VMPdata_colID_flag_map:
if 'flags' not in field_list: # Some column IDs represent boolean flags or small integers
type_list.append('u1') # These are all packed into a single 'flags' byte whose position
field_list.append('flags') # in the overall record is determined by the position of the first
# column ID of flag type. Even when several flag columns are present,
# there is still only one 'flags' byte ('u1') in the record
if 'flags' not in field_name_counts:
type_list.append(('flags', 'u1'))
field_name_counts['flags'] = 1
flag_name, flag_mask, flag_type = VMPdata_colID_flag_map[colID] flag_name, flag_mask, flag_type = VMPdata_colID_flag_map[colID]
flags_dict[flag_name] = (np.uint8(flag_mask), flag_type) flags_dict[flag_name] = (np.uint8(flag_mask), flag_type)
elif colID in VMPdata_colID_dtype_map:
field_name, field_type = VMPdata_colID_dtype_map[colID]
field_name_counts[field_name] += 1
count = field_name_counts[field_name]
if count > 1:
unique_field_name = '%s %d' % (field_name, count)
else:
unique_field_name = field_name
type_list.append((unique_field_name, field_type))
else: else:
try: raise NotImplementedError("column type %d not implemented" % colID)
field = VMPdata_colID_dtype_map[colID][0] return np.dtype(type_list), flags_dict, flags2_dict
if field in field_list:
field += str(len(field_list))
field_list.append(field)
type_list.append(VMPdata_colID_dtype_map[colID][1])
except KeyError:
print(list(zip(field_list, type_list)))
raise NotImplementedError("column type %d not implemented"
% colID)
return np.dtype(list(zip(field_list, type_list))), flags_dict, flags2_dict
def read_VMP_modules(fileobj, read_module_data=True): def read_VMP_modules(fileobj, read_module_data=True):