Source code for savu.data.data_structures.data_create

# Copyright 2014 Diamond Light Source Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
.. module:: data_create
   :platform: Unix
   :synopsis: A class inherited by Data class that deals with data object \
   creation.

.. moduleauthor:: Nicola Wadeson <scientificsoftware@diamond.ac.uk>

"""
import copy
import h5py
import numpy as np

from savu.core.utils import docstring_parameter
import savu.data.data_structures.data_notes as notes


[docs]class DataCreate(object): """ Class that deals with creating a data object. """ def __init__(self, name='DataCreate'): self.dtype = None self.remove = False
[docs] @docstring_parameter(notes._create.__doc__, notes._shape.__doc__, notes.axis_labels.__doc__, notes.patterns.__doc__) def create_dataset(self, *args, **kwargs): """ Set up required information when an output dataset has been created by a plugin. :arg Data: A data object :keyword tuple shape: The shape of the dataset :keyword list axis_labels: The axis_labels associated with the datasets :keyword patterns: The patterns associated with the dataset (optional,\ see note below) :keyword type dtype: Type of the data (optional: Defaults to \ np.float32) :keyword bool remove: Remove from framework after completion \ (no link in .nxs file) (optional: Defaults to False.) :keyword bool raw: Keep dark and flats (ImageKey or NoImageKey) {0} \n {1} \n {2} \n {3} """ #self.dtype = self.set_dtype(kwargs.get('dtype', np.float32)) self.remove = kwargs.get('remove', False) self.raw = kwargs.get('raw', False) self.transport = kwargs.get('transport', None) if self.transport: self.exp.meta_data.set('transport', self.transport) self._set_transport_data(self.transport) if len(args) == 1: self.__create_dataset_from_object(args[0]) else: self.__create_dataset_from_kwargs(kwargs) self.get_preview().set_preview([])
[docs] def set_dtype(self, dtype): if not dtype: if not self.data: plugin = self._get_plugin_data()._plugin.__class__.__name__ raise Exception("Please create all output datasets before " "setting plugin data in %s plugin.\n" % plugin) elif hasattr(self.data, 'dtype'): dtype = self.data.dtype else: h5 = h5py._hl.dataset.Dataset dtype = self.data.dtype if isinstance(self.data, h5) else \ self.data.data.dtype self.dtype = np.dtype(dtype)
[docs] def get_dtype(self): return self.dtype
def __create_dataset_from_object(self, data_obj): """ Create a dataset from an existing Data object. """ patterns = copy.deepcopy(data_obj.get_data_patterns()) self.__copy_labels(data_obj) self.__find_and_set_shape(data_obj) self._set_data_patterns(patterns) self.raw = data_obj.data if self.raw else None def __create_dataset_from_kwargs(self, kwargs): """ Create dataset from kwargs. """ try: shape = kwargs['shape'] self.__create_axis_labels(kwargs['axis_labels']) except KeyError: raise Exception("Please state axis_labels and shape when " "creating a new dataset") if isinstance(shape, DataCreate): self.__find_and_set_shape(shape) else: pData = self._get_plugin_data() pData._set_shape_before_tuning(copy.copy(shape)) self.set_shape(shape + tuple(pData.extra_dims)) if 'patterns' in kwargs: patterns = self.__copy_patterns(kwargs['patterns']) self._set_data_patterns(patterns) def __copy_patterns(self, copy_data): """ Copy patterns """ if isinstance(copy_data, DataCreate): patterns = copy.deepcopy(copy_data.get_data_patterns()) else: data = list(copy_data.keys())[0] pattern_list = copy_data[data] all_patterns = copy.deepcopy(data.get_data_patterns()) if len(pattern_list[0].split('.')) > 1: patterns = self.__copy_patterns_removing_dimensions( pattern_list, all_patterns, len(data.get_shape())) else: patterns = {} for pattern in pattern_list: patterns[pattern] = all_patterns[pattern] return patterns def __copy_patterns_removing_dimensions(self, pattern_list, all_patterns, nDims): """ Copy patterns but remove specified dimensions from them. """ copy_patterns = {} for new_pattern in pattern_list: name, all_dims = new_pattern.split('.') if name == '*': copy_patterns = all_patterns else: copy_patterns[name] = all_patterns[name] dims = tuple(map(int, all_dims.split(','))) dims = self._non_negative_directions(dims, nDims=nDims) dim_map = [a for a in range(nDims) if a not in dims] patterns = {} for name, pattern_dict in copy_patterns.items(): empty_flag = False for ddir in ['slice_dims', 'core_dims']: s_dims = self._non_negative_directions( pattern_dict[ddir], nDims=nDims) new_dims = [sd for sd in s_dims if sd not in dims] new_dims = [dim_map.index(n) for n in new_dims] pattern_dict[ddir] = tuple(new_dims) if not new_dims: empty_flag = True if empty_flag is False: patterns[name] = pattern_dict return patterns def __create_axis_labels(self, axis_labels): """ Create axis labels. """ if isinstance(axis_labels, DataCreate): self.__copy_labels(axis_labels) elif isinstance(axis_labels, dict): data = list(axis_labels.keys())[0] self.__copy_labels(data) self.__amend_axis_labels(axis_labels[data]) else: self.set_axis_labels(*axis_labels) # if parameter tuning if self._get_plugin_data().multi_params_dict: self.__add_extra_dims_labels() def __copy_labels(self, copy_data): """ Copy axis labels. """ nDims = copy.copy(copy_data.data_info.get('nDims')) axis_labels = \ copy.copy(copy_data.data_info.get('axis_labels')) self.data_info.set('nDims', nDims) self.data_info.set('axis_labels', axis_labels) # if parameter tuning if self._get_plugin_data().multi_params_dict: self.__add_extra_dims_labels() def __add_extra_dims_labels(self): """ Add axis labels to extra dimensions created by parameter tuning. """ params_dict = self._get_plugin_data().multi_params_dict # add multi_params axis labels from dictionary in pData nDims = self.data_info.get('nDims') axis_labels = self.data_info.get('axis_labels') axis_labels.extend([0]*len(params_dict)) for key, value in params_dict.items(): title = value['label'] name, unit = title.split('.') axis_labels[nDims + key] = {name: unit} # add parameter values to the meta_data self.meta_data.set(name, np.array(value['values'])) self.data_info.set('nDims', nDims + len(self.extra_dims)) self.data_info.set('axis_labels', axis_labels) def __amend_axis_labels(self, *args): """ Helper function to remove, replace/add or insert axis_labels into existing axis_labels """ removed_dims = 0 for arg in args[0]: if len(arg.split('.')) == 1: self.__remove_axis_labels(arg, removed_dims) removed_dims += 1 else: if len(arg.split('~')) == 1: self.__replace_axis_labels(arg) else: self.__insert_axis_labels(arg) def __remove_axis_labels(self, label, removed_dims): """ Remove axis labels. """ axis_labels = self.get_axis_labels() del axis_labels[int(label) - removed_dims] self.data_info.set( 'nDims', self.data_info.get('nDims') - 1) def __replace_axis_labels(self, label): """ Replace or add axis labels. """ axis_labels = self.get_axis_labels() label = label.split('.') axis_labels[int(label[0])] = {label[1]: label[2]} if int(label[0]) > self.data_info.get('nDims'): self.data_info.set( 'nDims', self.data_info.get('nDims') + 1) def __insert_axis_labels(self, label): """ Insert axis labels. """ axis_labels = self.get_axis_labels() label = label.split('~')[1].split('.') axis_labels.insert(int(label[0]), {label[1]: label[2]}) self.data_info.set('nDims', self.data_info.get('nDims') + 1) def _set_data_patterns(self, patterns): """ Add missing dimensions to patterns and populate data info dict. """ all_dims = list(range(len(self.get_shape()))) for p in patterns: pDims = patterns[p]['core_dims'] + patterns[p]['slice_dims'] for dim in all_dims: if dim not in pDims: patterns[p]['slice_dims'] += (dim,) self.data_info.set('data_patterns', patterns) def __find_and_set_shape(self, data): """ Add any extra dimensions, from parameter tuning, to the shape and set the shape in the framework. """ pData = self._get_plugin_data() pData._set_shape_before_tuning(copy.copy(data.get_shape())) new_shape = copy.copy(data.get_shape()) + tuple(pData.extra_dims) self.set_shape(new_shape)