# Source code for plugins.loaders.yaml_converter

# Copyright 2014 Diamond Light Source Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
.. module:: yaml_converter
   :platform: Unix
   :synopsis: 'A class to load data from a non-standard nexus/hdf5 file using \
               descriptions loaded from a yaml file.'

.. moduleauthor:: Nicola Wadeson <scientificsoftware@diamond.ac.uk>

"""

import os
import h5py
import yaml
import copy
import logging
import collections.abc as collections
import numpy as np  # used in exec so do not delete
from ast import literal_eval

import savu.plugins.utils as pu
import savu.plugins.loaders.utils.yaml_utils as yu
from savu.plugins.loaders.base_loader import BaseLoader
from savu.data.experiment_collection import Experiment


class YamlConverter(BaseLoader):
    """Base loader that builds data objects from a yaml description.

    The yaml file named by ``self.parameters['yaml_file']`` may ``inherit``
    other yaml files, ``import`` extra modules into this module's namespace
    and describe one or more data entries (``data``, ``axis_labels``,
    ``patterns`` and optional ``params``/``metadata``/``exp_metadata``).

    For string values containing ``$``, the text after the first ``$`` (up to
    any second ``$``) is evaluated as a Python expression by
    :meth:`update_value`. Because ``eval`` is used, only load trusted yaml
    files.

    Subclasses must implement :meth:`set_data`.
    """

    def __init__(self, name='YamlConverter'):
        super(YamlConverter, self).__init__(name)

    def setup(self, template=False, metadata=True):
        """Read the yaml file and create the data entries it describes.

        :param template: if True, return the merged yaml dictionary (after
            inheritance and imports are resolved) without creating any data.
        :param metadata: NOTE(review): currently unused here, because
            ``_set_entries`` always requests metadata. Confirm intent.
        :returns: the merged yaml dictionary when ``template`` is True,
            otherwise None.
        """
        # Read YAML file
        yfile = self.parameters['yaml_file']
        data_dict = yu.read_yaml(self._get_yaml_file(yfile))
        data_dict = self._check_for_inheritance(data_dict, {})
        self._check_for_imports(data_dict)
        data_dict.pop('inherit', None)
        data_dict.pop('import', None)
        if template:
            return data_dict

        data_dict = self._add_template_updates(data_dict)
        self._set_entries(data_dict)

    def _get_yaml_file(self, yaml_file):
        """Resolve ``yaml_file`` to an existing path.

        Tries, in order:

        1. the path as given (made absolute);
        2. the part after ``'Savu/'`` relative to the Savu source root;
        3. the path relative to this module's directory.

        :raises Exception: if ``yaml_file`` is None or no candidate exists.
        """
        if yaml_file is None:
            raise Exception('Please pass a yaml file to the yaml loader.')

        # try the absolute path
        yaml_abs = os.path.abspath(yaml_file)
        if os.path.exists(yaml_abs):
            return yaml_abs

        # try adding the path to savu
        if len(yaml_file.split('Savu/')) > 1:
            yaml_savu = os.path.join(os.path.dirname(__file__), "../../../",
                                     yaml_file.split('Savu/')[1])
            if os.path.exists(yaml_savu):
                return yaml_savu

        # try adding the path to the templates folder
        yaml_templ = os.path.join(os.path.dirname(__file__), yaml_file)
        if os.path.exists(yaml_templ):
            return yaml_templ

        raise Exception('The yaml file does not exist %s' % yaml_file)

    def _add_template_updates(self, ddict):
        """Apply the ``all`` section and ``template_param`` overrides.

        Each ``key: value`` pair in the (removed) ``all`` section overwrites
        that key in every data entry that already defines it. Entries of
        ``self.parameters['template_param']`` are merged into the matching
        entry's ``params`` dictionary, which is created if missing.

        :returns: the updated ``ddict``.
        """
        all_entries = ddict.pop('all', None) or {}
        for key, value in all_entries.items():
            for entry in ddict.values():
                if isinstance(entry, collections.Mapping) and key in entry:
                    entry[key] = value

        for entry, updates in self.parameters['template_param'].items():
            ddict[entry].setdefault('params', {}).update(updates)
        return ddict

    def _check_for_imports(self, ddict):
        """Import the modules listed under ``import`` into module globals.

        Each item is either ``"pkg.mod"``, which binds the top-level package
        name like the ``import`` statement does, or ``"pkg.mod as alias"``,
        which binds ``alias`` to ``pkg.mod`` itself. The imported names then
        become available to expressions evaluated by :meth:`update_value`.

        :raises Exception: for an item that matches neither form.
        """
        import importlib

        imports = ddict.get('import') or []
        if isinstance(imports, str):
            imports = [imports]
        for imp in imports:
            tokens = imp.split()
            if len(tokens) == 3 and tokens[1] == 'as':
                globals()[tokens[2]] = importlib.import_module(tokens[0])
            elif len(tokens) == 1:
                mod = __import__(tokens[0])
                globals()[mod.__name__] = mod
            else:
                raise Exception(
                    "Invalid yaml import '%s': expected 'module' or "
                    "'module as name'." % imp)

    def _check_for_inheritance(self, ddict, inherit, override=False):
        """Recursively merge inherited yaml files into ``inherit``.

        Every file named under ``ddict['inherit']`` (a name or a list of
        names, with the string ``'None'`` skipped) is read and merged first,
        including its own ancestors. ``ddict`` is merged last, so its values
        take precedence.

        :returns: the merged dictionary (``inherit``, updated in place).
        """
        if 'inherit' in ddict:
            idict = ddict['inherit']
            idict = idict if isinstance(idict, list) else [idict]
            for i in idict:
                if i != 'None':
                    new_dict = yu.read_yaml(self._get_yaml_file(i))
                    new_dict, isoverride = \
                        self.__override(inherit, new_dict, override)
                    inherit.update(new_dict)
                    inherit = self._check_for_inheritance(
                        new_dict, inherit, override=isoverride)
        self._update(inherit, ddict)
        return inherit

    def __override(self, inherit, ddict, override):
        """Rename entries of an inherited dict according to ``override``.

        :param override: mapping ``old_name -> new_name`` from the child
            yaml, or a false value for no renaming.
        :returns: ``(ddict, isoverride)``, where ``isoverride`` is this
            file's own ``override`` entry (removed from ``ddict``), or False.
        """
        isoverride = False
        if 'override' in ddict:
            isoverride = ddict.pop('override')
        if override:
            for old, new in override.items():
                ddict[new] = ddict.pop(old)
                if new in inherit:
                    self._update(ddict[new], inherit[new])
        return ddict, isoverride

    def _update(self, d, u):
        """Recursively merge mapping ``u`` into ``d`` (in place) and return ``d``.

        Nested mappings are merged. Any other value in ``u`` replaces the one
        in ``d``. A non-mapping value in ``d`` that ``u`` overrides with a
        mapping is replaced rather than crashing.
        """
        for k, v in u.items():
            if isinstance(v, collections.Mapping):
                existing = d.get(k)
                if not isinstance(existing, collections.Mapping):
                    existing = {}
                d[k] = self._update(existing, v)
            else:
                d[k] = v
        return d

    def _set_entries(self, ddict):
        """Create a data object for every top-level entry of ``ddict``."""
        for name in list(ddict):
            self.get_description(ddict[name], name)

    def get_description(self, entry, name, metadata=True):
        """Create the ``in_data`` object ``name`` described by ``entry``.

        :param entry: the yaml dictionary for this dataset. ``data`` is
            required, and ``params`` is applied first if present.
        :param metadata: if True, also read axis labels, patterns and
            metadata.
        :raises Exception: if ``entry`` has no ``data`` section.
        """
        # set params first as we may need them subsequently
        if 'params' in entry:
            self._set_params(entry['params'])

        # --------------- check for data entry -----------------------------
        if 'data' not in entry:
            emsg = 'Please specify the data information in the yaml file.'
            raise Exception(emsg)
        data_obj = self.exp.create_data_object("in_data", name)
        data_obj = self.set_data(data_obj, entry['data'])

        if metadata:
            self._get_meta_data_descriptions(entry, data_obj)

    def _get_meta_data_descriptions(self, entry, data_obj):
        """Apply axis labels, patterns and optional metadata to ``data_obj``.

        :raises Exception: if ``axis_labels`` or ``patterns`` is missing.
        """
        # --------------- check for axis label information -----------------
        if 'axis_labels' not in entry:
            raise Exception('Please specify the axis labels in the yaml file.')
        self._set_axis_labels(data_obj, entry['axis_labels'])

        # --------------- check for data access patterns -------------------
        if 'patterns' not in entry:
            raise Exception('Please specify the patterns in the yaml file.')
        self._set_patterns(data_obj, entry['patterns'])

        # add any additional metadata
        if 'metadata' in entry:
            self._set_metadata(data_obj, entry['metadata'])
        self.set_data_reduction_params(data_obj)

        if 'exp_metadata' in entry:
            self._set_metadata(data_obj, entry['exp_metadata'], exp=True)

    def set_data(self, name, entry):
        """Populate the data object from the yaml ``data`` section.

        Must be implemented by subclasses. ``get_description`` calls it as
        ``set_data(data_obj, entry['data'])`` and uses the return value as
        the data object.
        """
        raise NotImplementedError('Please implement "set_data" function'
                                  ' in the loader')

    def _set_keywords(self, dObj):
        """Return the backing file name and shape of ``dObj`` as strings."""
        filepath = str(dObj.backing_file.filename)
        shape = str(dObj.get_shape())
        return {'dfile': filepath, 'dshape': shape}

    def __get_wildcard_values(self, dObj):
        """Return ``dObj``'s ``wildcard_values`` data_info entry, or None."""
        if 'wildcard_values' in dObj.data_info.get_dictionary():
            return dObj.data_info.get('wildcard_values')
        return None

    def update_value(self, dObj, value, itr=0):
        """Evaluate a ``$``-expression from the yaml, or return ``value`` as is.

        For a string containing ``$``, the text after the first ``$`` has
        parameter names substituted (``_convert_string``) and is passed to
        ``eval`` with this module's globals and this function's locals. Byte
        results are converted to str. If evaluation fails, it is retried once
        with ``index(`` rewritten to ``index(b`` (for bytes look-ups).

        The locals ``dObj``, ``itr``, ``dshape``, ``dfile`` and ``wildcard``
        are deliberately visible to the evaluated expression, so do not
        rename them. ``dshape``, ``dfile`` and ``wildcard`` are only defined
        when ``dObj`` is not None.

        :raises Exception: if evaluation fails. The original error is kept
            as ``__cause__``.
        """
        # setting the keywords (read by the eval'd expression via locals())
        if dObj is not None:
            dshape = dObj.get_shape()
            dfile = dObj.backing_file
            globals()['dfile'] = dfile
            wildcard = self.__get_wildcard_values(dObj)

        if isinstance(value, str):
            split = value.split('$')
            if len(split) > 1:
                value = self._convert_string(dObj, split[1])
                # SECURITY: eval of yaml content; only use trusted yaml files.
                try:
                    value = eval(value, globals(), locals())
                    value = self._convert_bytes(value)
                except Exception as e:
                    # pure f-string: a '%' inside the expression must not be
                    # treated as a %-format directive
                    msg = f"Error evaluating value: '{value}' \n {e}"
                    try:
                        value = value.replace("index(", "index(b")
                        value = eval(value, globals(), locals())
                        value = self._convert_bytes(value)
                    except Exception:
                        raise Exception(msg) from e
        return value

    def _convert_string(self, dObj, string):
        """Substitute parameter names in ``string`` with their values.

        A parameter whose value contains ``$`` contributes the text after the
        ``$``. Any other string value is inserted quoted, and non-strings via
        ``str()``. Substitution repeats recursively on the result.
        """
        for old, new in self.parameters.items():
            if old in string:
                if isinstance(new, str):
                    split = new.split('$')
                    if len(split) > 1:
                        new = split[1]
                    else:  # nothing left to split
                        new = "'%s'" % new
                string = self._convert_string(
                    dObj, string.replace(old, str(new)))
        return string

    def _convert_bytes(self, value):
        """Convert bytes (or byte-string arrays) to str, for back compatibility.

        Handles scalar ``bytes``, fixed-width byte-string arrays (dtype kind
        ``'S'``) of any shape, and object arrays holding ``bytes``. Empty and
        0-d arrays are handled without raising. Anything else is returned
        unchanged.
        """
        if isinstance(value, bytes):
            return value.decode("ascii")
        if isinstance(value, np.ndarray) and value.size:
            if value.dtype.kind == 'S':
                return value.astype(str)
            if value.dtype == object and isinstance(value.flat[0], bytes):
                # str() of a bytes object yields "b'...'", so convert to a
                # fixed-width byte array first and decode that instead
                return value.astype(bytes).astype(str)
        return value

    def _set_params(self, params):
        """Merge ``params`` into ``self.parameters`` and open file parameters.

        Parameters whose name ends in ``file`` are evaluated and bound as
        module globals under that name, then removed from ``params``. If
        evaluation fails with an I/O error, the path is retried as Savu test
        data (``_check_for_test_data``).
        """
        # Update variable parameters that are revealed in the template
        params = self._update_template_params(params)
        self.parameters.update(params)
        # find files, open and add to the namespace then delete file params
        files = [k for k in params if k.endswith('file')]
        for f in files:
            param = params[f]
            try:
                globals()[str(f)] = self.update_value(None, param)
            except Exception as err:
                # update_value wraps eval errors, so inspect the cause too
                io_failure = isinstance(err, IOError) or \
                    isinstance(err.__cause__, IOError)
                if not io_failure:
                    raise
                self._check_for_test_data(f, param)
            del params[f]

    def _check_for_test_data(self, f, param):
        """Retry opening file parameter ``f`` with its path resolved locally.

        The first single-quoted substring of ``param`` is taken as the path.
        It is resolved with ``_get_yaml_file`` and substituted back before
        re-evaluating. The parameter is then dropped from ``self.parameters``.

        :raises IOError: if ``param`` holds no quoted path or it can't be found.
        """
        # check if this is Savu test data
        quoted = param.split("'")[1:2]
        if not quoted:
            raise IOError(
                "Unable to open file parameter '%s': %s" % (f, param))
        path = quoted[0]
        try:
            filename = self._get_yaml_file(path)
        except Exception as err:
            raise IOError("Unable to locate file '%s' for parameter '%s'."
                          % (path, f)) from err
        param = param.replace(path, filename)
        globals()[str(f)] = self.update_value(None, param)
        self.parameters.pop(f, None)

    def _update_template_params(self, params):
        """Resolve template parameters in ``params`` (in place) and return it.

        A value recognised by ``pu.is_template_param`` takes the user's value
        from ``self.parameters`` if set, otherwise its template default
        (element 1 of the returned pair).
        """
        for k, v in params.items():
            v = pu.is_template_param(v)
            if v is not False:
                params[k] = self.parameters[k] if k in self.parameters \
                    else v[1]
        return params

    def _set_axis_labels(self, dObj, labels):
        """Set axis labels (and per-axis metadata) on ``dObj``.

        :param labels: mapping keyed ``0..n-1``. Each value needs ``dim``,
            ``name``, ``value`` and ``units``, each of which may be a
            ``$``-expression. The label for axis ``dim`` is ``"name.units"``.
            A non-None ``value`` is stored in ``dObj.meta_data`` under
            ``name``.
        """
        axis_labels = [None] * len(labels)
        for d in range(len(labels)):
            self._check_label_entry(labels[d])
            label = labels[d]
            for key in list(label):
                label[key] = self.update_value(dObj, label[key])
            axis_labels[label['dim']] = label['name'] + '.' + label['units']
            if label['value'] is not None:
                dObj.meta_data.set(label['name'], label['value'])
        dObj.set_axis_labels(*axis_labels)

    def _check_label_entry(self, label):
        """Ensure an axis-label entry has all required fields.

        :raises Exception: if ``label`` is not a mapping or lacks any of
            ``dim``, ``name``, ``value`` or ``units``.
        """
        required = ['dim', 'name', 'value', 'units']
        if not isinstance(label, collections.Mapping):
            missing = required
        else:
            missing = [field for field in required if field not in label]
        if missing:
            raise Exception("dim, name, value and units are required fields "
                            "for axis labels (missing: %s)"
                            % ", ".join(missing))

    def _set_patterns(self, dObj, patterns):
        """Add each named access pattern (``core_dims``/``slice_dims``) to ``dObj``."""
        for key, dims in patterns.items():
            core_dims = self.__get_tuple(
                self.update_value(dObj, dims['core_dims']))
            slice_dims = self.__get_tuple(
                self.update_value(dObj, dims['slice_dims']))
            dObj.add_pattern(key, core_dims=core_dims, slice_dims=slice_dims)

    def __get_tuple(self, val):
        """Return ``val`` as a tuple of dims.

        Strings are parsed with ``literal_eval`` and lists are converted to
        tuples. Other values, including tuples, are returned unchanged.
        """
        if isinstance(val, str):
            val = literal_eval(val)
        if isinstance(val, list):
            val = tuple(val)
        return val

    def _set_metadata(self, dObj, mdata, exp=False):
        """Store each ``key: {'value': ...}`` entry of ``mdata`` as metadata.

        Values may be ``$``-expressions. With ``exp=True``, the metadata goes
        to the experiment (``dObj.exp``) rather than to ``dObj``.
        """
        populate = dObj.exp if exp else dObj
        for key, value in mdata.items():
            value = self.update_value(dObj, value['value'])
            populate.meta_data.set(key, value)