Source code for

# Copyright 2014 Diamond Light Source Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
.. module:: data
   :platform: Unix
   :synopsis: The Data class dynamically inherits from transport specific data\
   class and holds the data array, along with associated information.

.. moduleauthor:: Nicola Wadeson <>
"""


import savu.core.utils as cu
from import MetaData
import as dsu
from import Preview
from import DataCreate

class Data(DataCreate):
    """The Data class dynamically inherits from transport specific data class
    and holds the data array, along with associated information.
    """

    def __init__(self, name, exp):
        """ Initialise an empty dataset.

        :params str name: The dataset name (stored in data_info as 'name').
        :params exp: Stored as ``self.exp`` and handed on to copies made by
            :meth:`__deepcopy__` (presumably the owning experiment object).
        """
        super(Data, self).__init__(name)
        self.meta_data = MetaData()
        self.related = {}
        self.pattern_list = self.__get_available_pattern_list()
        self.data_info = MetaData()
        self.__initialise_data_info(name)
        self._preview = Preview(self)
        self.exp = exp
        self.group_name = None
        # NOTE(review): in the extracted source the targets of this assignment
        # and of ``self.data = None`` below were missing (only "= None"
        # survived, which is a syntax error). They are reconstructed as
        # ``self.group`` and ``self.data`` - TODO confirm against upstream.
        self.group = None
        self._plugin_data_obj = None
        self.raw = None
        self.backing_file = None
        self.data = None
        self.next_shape = None
        self.orig_shape = None
        self.previous_pattern = None
        self.transport_data = None

    def __initialise_data_info(self, name):
        """ Initialise entries in the data_info meta data.

        Sets 'name', an empty 'data_patterns' dict, and 'shape'/'nDims' to
        None (i.e. not yet known).
        """
        self.data_info.set('name', name)
        self.data_info.set('data_patterns', {})
        self.data_info.set('shape', None)
        self.data_info.set('nDims', None)

    def _set_plugin_data(self, plugin_data_obj):
        """ Encapsulate a PluginData object. """
        self._plugin_data_obj = plugin_data_obj

    def _clear_plugin_data(self):
        """ Set encapsulated PluginData object to None. """
        self._plugin_data_obj = None

    def _get_plugin_data(self):
        """ Get encapsulated PluginData object.

        :returns: The object stored by :meth:`_set_plugin_data`.
        :raises Exception: If no PluginData object is currently associated.
        """
        if self._plugin_data_obj is None:
            raise Exception("There is no PluginData object associated with "
                            "the Data object.")
        return self._plugin_data_obj

    def get_preview(self):
        """ Get the Preview instance associated with the data object """
        return self._preview

    def _set_transport_data(self, transport):
        """ Import the data transport mechanism and attach an instance of it.

        The class is resolved with ``cu.import_class`` from a module path
        ending in ``<transport>_transport_data``, instantiated with this Data
        object and stored in ``self.transport_data``. The transport name is
        recorded in data_info under 'transport'. Nothing is returned.

        :params str transport: The transport name used to build the module
            path.
        """
        # NOTE(review): the package prefix of this module path was lost when
        # the source was extracted (the literal had become "", which cannot be
        # imported). Restored as "savu.data.transport_data." - TODO confirm.
        transport_data = ("savu.data.transport_data." + transport
                          + "_transport_data")
        transport_data = cu.import_class(transport_data)
        self.transport_data = transport_data(self)
        self.data_info.set('transport', transport)

    def _get_transport_data(self):
        """ Return the transport instance set by :meth:`_set_transport_data`
        (None if it has not been set). """
        return self.transport_data

    def __deepcopy__(self, memo):
        """ Copy the data object.

        A new ``Data`` with the same name and ``exp`` is created and filled in
        by ``dsu._deepcopy_data_object``; ``memo`` is not used here.
        """
        name = self.data_info.get('name')
        return dsu._deepcopy_data_object(self, Data(name, self.exp))

    def get_data_patterns(self):
        """ Get data patterns associated with this data object.

        :returns: A dictionary of associated patterns.
        :rtype: dict
        """
        return self.data_info.get('data_patterns')

    def _set_previous_pattern(self, pattern):
        """ Record ``pattern`` as the previous pattern. """
        self.previous_pattern = pattern

    def get_previous_pattern(self):
        """ Return the pattern recorded by :meth:`_set_previous_pattern`. """
        return self.previous_pattern

    def set_shape(self, shape):
        """ Set the dataset shape.

        :raises Exception: If ``nDims`` is already set and disagrees with
            ``len(shape)``.
        """
        self.data_info.set('shape', shape)
        self.__check_dims()

    def set_original_shape(self, shape):
        """ Set the original data shape before previewing.

        Also sets the current shape to ``shape`` via :meth:`set_shape`.
        """
        self.orig_shape = shape
        self.set_shape(shape)

    def get_original_shape(self):
        """
        Returns the original shape of the data before previewing

        Returns
        -------
        tuple
            Original data shape.
        """
        return self.orig_shape

    def get_shape(self):
        """ Get the dataset shape

        :returns: data shape
        :rtype: tuple
        """
        return self.data_info.get('shape')

    def __check_dims(self):
        """ Check the ``shape`` and ``nDims`` entries in the data_info
        meta_data dictionary are equal (only when ``nDims`` is set).

        :raises Exception: On a mismatch.
        """
        nDims = self.data_info.get("nDims")
        shape = self.data_info.get('shape')
        if nDims:
            if len(shape) != nDims:
                error_msg = ("The number of axis labels, %d, does not "
                             "coincide with the number of data "
                             "dimensions %d." % (nDims, len(shape)))
                raise Exception(error_msg)

    def _set_name(self, name):
        """ Set the dataset name in data_info. """
        self.data_info.set('name', name)

    def get_name(self, orig=False):
        """ Get data name.

        :keyword bool orig: Set this flag to true to return the original
            cloned dataset name if this dataset is a clone
        :returns: the name associated with the dataset
        :rtype: str
        """
        if orig:
            dinfo = self.data_info.get_dictionary()
            return dinfo['clone'] if 'clone' in dinfo else dinfo['name']
        return self.data_info.get('name')

    def __get_available_pattern_list(self):
        """ Get a list of ALL pattern names that are currently allowed in the
        framework (as reported by ``dsu.get_available_pattern_types``). """
        return dsu.get_available_pattern_types()

    def add_pattern(self, dtype, **kwargs):
        """ Add a pattern.

        :params str dtype: The *type* of pattern to add, which must be one of
            the names in ``self.pattern_list`` (the available pattern types).
        :keyword tuple core_dims: Dimension indices of core dimensions
        :keyword tuple slice_dims: Dimension indices of slice dimensions
        :raises Exception: If ``dtype`` is not an available pattern type, if
            a keyword is given an empty sequence, or if the total number of
            pattern dimensions disagrees with the dataset's ``nDims``.
        """
        if dtype not in self.pattern_list:
            # Format the message here: passing the values as extra Exception
            # arguments would leave the '%s' placeholders unfilled.
            raise Exception("The data pattern '%s' does not exist. Please "
                            "choose from the following list: \n'%s'"
                            % (dtype, str(self.pattern_list)))

        nDims = 0
        for args in kwargs:
            dlen = len(kwargs[args])
            if not dlen:
                raise Exception("Pattern Error: Pattern %s must have at"
                                " least one %s" % (dtype, args))
            nDims += dlen
            self.data_info.set(['data_patterns', dtype, args], kwargs[args])

        self.__convert_pattern_dimensions(dtype)
        if self.get_shape():
            diff = len(self.get_shape()) - nDims
            if diff:
                pattern = {dtype: self.get_data_patterns()[dtype]}
                self._set_data_patterns(pattern)
                nDims += diff

        try:
            expected = self.data_info.get("nDims")
        except KeyError:
            expected = None
        if expected is None:
            # nDims is initialised to None and is normally fixed by
            # set_axis_labels; if it is still unset, adopt this pattern's
            # dimension count instead of failing on '%d' % None.
            self.data_info.set('nDims', nDims)
        elif nDims != expected:
            err_msg = ("The pattern %s has an incorrect number of "
                       "dimensions: %d required but %d specified."
                       % (dtype, expected, nDims))
            raise Exception(err_msg)

    def add_volume_patterns(self, x, y, z):
        """ Adds volume patterns

        Always adds VOLUME_XZ. VOLUME_YZ and VOLUME_XY are added only when
        ``y`` is truthy, and VOLUME_3D additionally requires ``nDims > 3``.

        :params int x: dimension to be associated with x-axis
        :params int y: dimension to be associated with y-axis
        :params int z: dimension to be associated with z-axis
        """
        # NOTE(review): ``if y`` also treats dimension index 0 as "no y-axis";
        # confirm callers never pass y=0 as a real dimension.
        self.add_pattern("VOLUME_XZ", **self.__get_dirs_for_volume(x, z, y))
        if y:
            self.add_pattern(
                "VOLUME_YZ", **self.__get_dirs_for_volume(y, z, x))
            self.add_pattern(
                "VOLUME_XY", **self.__get_dirs_for_volume(x, y, z))
        if self.data_info.get("nDims") > 3 and y:
            self.add_pattern("VOLUME_3D", **self.__get_dirs_for_volume_3D())

    def __get_dirs_for_volume(self, dim1, dim2, sdir, dim3=None):
        """ Calculate core_dir and slice_dir for a volume pattern.

        ``(dim1, dim2)`` become the core dims. If ``sdir`` is an int it is the
        first slice dim, followed by every other dimension in ascending
        order. ``dim3`` is unused.
        """
        slice_dir = [sdir] if type(sdir) is int else []
        slice_dir += [ddir for ddir in range(self.data_info.get("nDims"))
                      if ddir not in (dim1, dim2, sdir)]
        return {'core_dims': (dim1, dim2), 'slice_dims': tuple(slice_dir)}

    def __get_dirs_for_volume_3D(self):
        """ Build a VOLUME_3D pattern: the core dims are the union of the core
        dims of the YZ, XY and XZ volume patterns; all remaining dimensions
        are slice dims. """
        patterns = self.get_data_patterns()
        cdim = []
        for v in ['VOLUME_YZ', 'VOLUME_XY', 'VOLUME_XZ']:
            cdim += patterns[v]['core_dims']
        cdim = set(cdim)
        sdim = tuple(set(range(self.data_info.get("nDims"))).difference(cdim))
        return {"core_dims": tuple(cdim), "slice_dims": sdim}

    def set_axis_labels(self, *args):
        """ Set the axis labels associated with each data dimension.

        Also sets ``nDims`` to the number of arguments.

        :arg str: Each arg should be of the form ``name.unit``. If ``name`` is\
        a data_obj.meta_data entry, it will be output to the final .nxs file.
        Dict arguments are stored unchanged. Arguments that cannot be split
        into ``name.unit`` are skipped.
        """
        self.data_info.set('nDims', len(args))
        axis_labels = []
        for arg in args:
            if isinstance(arg, dict):
                axis_labels.append(arg)
                continue
            try:
                axis = arg.split('.')
                axis_labels.append({axis[0]: axis[1]})
            except (AttributeError, IndexError, TypeError):
                # Best effort: non-string args (e.g. data) or strings with no
                # '.unit' part arrive here and are deliberately skipped,
                # although that may indicate a caller error.
                pass
        self.data_info.set('axis_labels', axis_labels)

    def get_axis_labels(self):
        """ Get axis labels.

        :returns: Axis labels
        :rtype: list(dict)
        """
        return self.data_info.get('axis_labels')

    def get_data_dimension_by_axis_label(self, name, contains=False,
                                         exists=False):
        """ Get the dimension of the data associated with a particular
        axis_label.

        :param str name: The name of the axis_label
        :keyword bool contains: Set this flag to true if the name is only part
            of the axis_label name
        :keyword bool exists: Set to True to return False rather than
            Exception. Note that dimension 0 is also falsy, so callers should
            test the result with ``is False``.
        :returns: The associated axis number
        :rtype: int
        :raises Exception: If no label matches and ``exists`` is False.
        """
        axis_labels = self.data_info.get('axis_labels')
        for dim, labels in enumerate(axis_labels):
            if contains is True:
                if any(name in key for key in labels):
                    return dim
            elif name in labels:
                return dim
        if exists:
            return False
        raise Exception("Cannot find the specifed axis label.")

    def _finalise_patterns(self):
        """ Adds a main axis (fastest changing) to SINOGRAM and PROJECTON
        patterns, but only when both patterns exist and the data has more
        than two dimensions. """
        check = (self.__check_pattern('SINOGRAM')
                 + self.__check_pattern('PROJECTION'))
        if check == 2 and len(self.get_shape()) > 2:
            self.__set_main_axis('SINOGRAM')
            self.__set_main_axis('PROJECTION')

    def __check_pattern(self, pattern_name):
        """ Check if a pattern exists.

        :returns: 1 if it exists, otherwise 0.
        :rtype: int
        """
        patterns = self.get_data_patterns()
        try:
            patterns[pattern_name]
        except KeyError:
            return 0
        return 1

    def __convert_pattern_dimensions(self, dtype):
        """ Replace negative indices in pattern kwargs.

        Any existing 'main_dir' entry is removed first; the remaining entries
        are rewritten in place with non-negative indices, using their total
        length as the number of dimensions.
        """
        pattern = self.get_data_patterns()[dtype]
        if 'main_dir' in pattern:
            del pattern['main_dir']
        nDims = sum(len(dims) for dims in pattern.values())
        for key in list(pattern):
            pattern[key] = self._non_negative_directions(pattern[key], nDims)

    def _non_negative_directions(self, ddirs, nDims):
        """ Replace negative indexing values with positive counterparts.

        :params tuple(int) ddirs: data dimension indices
        :params int nDims: The number of data dimensions
        :returns: non-negative data dimension indices
        :rtype: tuple(int)
        """
        return tuple(nDims + d if d < 0 else d for d in ddirs)

    def __set_main_axis(self, pname):
        """ Set the ``main_dir`` pattern kwarg to the fastest changing
        dimension.

        The main dir is a slice dim of ``pname`` that is a core dim of the
        other pattern (SINOGRAM <-> PROJECTION), falling back to the first
        slice dim of ``pname`` if there is no such dimension.
        """
        patterns = self.get_data_patterns()
        n1 = 'PROJECTION' if pname == 'SINOGRAM' else 'SINOGRAM'
        d1 = patterns[n1]['core_dims']
        d2 = patterns[pname]['slice_dims']
        tdir = set(d1).intersection(set(d2))

        # this is required when a single sinogram exists in the mm case, and a
        # dimension is added via parameter tuning.
        if not tdir:
            tdir = [d2[0]]

        self.data_info.set(['data_patterns', pname, 'main_dir'],
                           list(tdir)[0])

    def get_axis_label_keys(self):
        """ Get axis_label names

        :returns: A list containing associated axis names for each dimension
        :rtype: list(str)
        """
        axis_labels = self.data_info.get('axis_labels')
        return [key for labels in axis_labels for key in labels]

    def amend_axis_label_values(self, slice_list):
        """ Amend all axis label values based on the slice_list parameter.\
        This is required if the data is reduced.

        For each dimension whose label name is a meta_data entry, that entry
        is sliced along its first axis by the matching ``slice_list`` item.
        """
        axis_labels = self.get_axis_labels()
        for i, dim_slice in enumerate(slice_list):
            label = list(axis_labels[i].keys())[0]
            if label in self.meta_data.get_dictionary():
                values = self.meta_data.get(label)
                preview_sl = [slice(None)] * len(values.shape)
                preview_sl[0] = dim_slice
                self.meta_data.set(label, values[tuple(preview_sl)])

    def get_core_dimensions(self):
        """ Get the core data dimensions associated with the current pattern.

        :returns: value associated with pattern key ``core_dims``
        :rtype: tuple
        :raises Exception: If no PluginData object is associated.
        """
        pattern = self._get_plugin_data().get_pattern()
        return list(pattern.values())[0]['core_dims']

    def get_slice_dimensions(self):
        """ Get the slice data dimensions associated with the current pattern.

        :returns: value associated with pattern key ``slice_dims``
        :rtype: tuple
        :raises Exception: If no PluginData object is associated.
        """
        pattern = self._get_plugin_data().get_pattern()
        return list(pattern.values())[0]['slice_dims']

    def get_itemsize(self):
        """ Returns bytes per entry

        If ``get_dtype()`` returns a falsy value, ``set_dtype(None)`` is
        called first and the dtype is re-read.

        :returns: the ``itemsize`` of the dataset's dtype
        """
        dtype = self.get_dtype()
        # NOTE(review): if get_dtype() returns a numpy dtype, a non-structured
        # dtype can be falsy (its len() is 0), which would trigger
        # set_dtype(None) here; consider ``dtype is None`` - confirm.
        if not dtype:
            self.set_dtype(None)
            dtype = self.get_dtype()
        return dtype.itemsize