# Source code for plugins.loaders.hdf5_template_loader

# Copyright 2014 Diamond Light Source Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
.. module:: hdf5_template_loader
   :platform: Unix
   :synopsis: A class to load data from a non-standard nexus/hdf5 file using \
   descriptions loaded from a yaml file.

.. moduleauthor:: Nicola Wadeson <scientificsoftware@diamond.ac.uk>

"""

import os
import h5py
import fnmatch
import difflib
import numpy as np

import savu.core.utils as cu
from savu.plugins.utils import register_plugin
from savu.plugins.loaders.yaml_converter import YamlConverter
from savu.data.data_structures.data_types.stitch_data import StitchData


@register_plugin
class Hdf5TemplateLoader(YamlConverter):
    """Load data from a non-standard NeXus/HDF5 file using a yaml template.

    Each template data entry supplies at least a ``path`` to the dataset
    inside the hdf5 file. If the final component of ``path`` contains a
    ``*`` wildcard, every dataset in the parent group matching the pattern
    is loaded and the results are combined into a single ``StitchData``
    object.
    """

    def __init__(self, name='Hdf5TemplateLoader'):
        super(Hdf5TemplateLoader, self).__init__(name)

    def set_data(self, dObj, data):
        """Attach the hdf5 dataset described by ``data`` to ``dObj``.

        :param dObj: The data object to populate.
        :param dict data: Template entry for this dataset. ``'path'`` is
            required; ``'file'`` optionally overrides the experiment's
            ``data_file`` metadata; ``'stack'``/``'cat'`` and ``'remove'``
            are used when ``path`` contains a wildcard.
        :returns: The populated data object.
        :raises Exception: If no path is given, or the (updated) path is not
            present in the hdf5 file.
        """
        path = data.get('path')
        if not path:
            emsg = 'Please specify the path to the data in the h5 file.'
            raise Exception(emsg)

        # Only consult the experiment metadata when no explicit file is given.
        if 'file' in data:
            file_path = data['file']
        else:
            file_path = self.exp.meta_data.get("data_file")
        file_path = self.update_value(dObj, file_path)
        # The file handle is stored on the data object (not closed here)
        # because dObj.data is assigned an h5py dataset that lives inside it.
        dObj.backing_file = h5py.File(file_path, 'r')

        # A wildcard in the dataset name means several datasets are stitched.
        if '*' in os.path.basename(path):
            return self._stitch_data(dObj, path, data)
        return self._setup_data(dObj, path)

    def _stitch_data(self, dObj, path, data):
        """Load every dataset matching a wildcard ``path`` and stitch them.

        Matching datasets are ordered by the integer that replaces the
        wildcard in their name; those integers are stored in the data
        object's ``data_info`` under ``'wildcard_values'``.

        :param dObj: The data object to populate.
        :param str path: hdf5 path whose basename contains a wildcard.
        :param dict data: Template entry (see ``set_data``).
        :returns: The data object, with ``data`` set to a ``StitchData``
            instance if any datasets matched.
        """
        stype, dim = self._get_stitching_info(data)
        remove = data.get('remove')
        group_name, data_name = os.path.split(path)

        # find all datasets in the group whose names match the pattern
        group = dObj.backing_file.require_group(group_name)
        matches = fnmatch.filter(list(group.keys()), data_name)

        # order the matches by the number substituted for the wildcard
        ordered = sorted(
            (self._get_wildcard_number(m, data_name), m) for m in matches)
        dObj.data_info.set('wildcard_values', [num for num, _ in ordered])

        data_obj_list = []
        for _, match in ordered:
            match_path = os.path.join(group_name, match)
            sub_obj = self.exp.create_data_object('in_data', match)
            sub_obj.backing_file = dObj.backing_file
            data_obj_list.append(self._setup_data(sub_obj, match_path))
            # Drop the temporary entry from the experiment index; the
            # sub-object is kept only in data_obj_list.
            del self.exp.index['in_data'][match]

        if data_obj_list:
            dObj.data = StitchData(data_obj_list, stype, dim, remove=remove)
            dObj.set_original_shape(dObj.data.get_shape())
        else:
            cu.user_message("The data set %s is empty." % data_name)
        return dObj

    @staticmethod
    def _get_wildcard_number(name, pattern):
        """Return the integer that ``name`` substitutes for the wildcard(s).

        Characters present in ``name`` but not in ``pattern`` (according to
        a character-level ``difflib.ndiff``) are concatenated and parsed as
        an integer.

        :raises ValueError: If the substituted text is empty or not numeric.
        """
        diff_number = ''.join(
            diff[2:] for diff in difflib.ndiff(name, pattern)
            if diff.startswith('- '))
        try:
            return int(diff_number)
        except ValueError:
            raise ValueError(
                "Unable to extract a numeric wildcard value from '%s' "
                "using the pattern '%s'." % (name, pattern))

    def _get_stitching_info(self, data):
        """Return ``(stitch_type, dim)`` from the template entry.

        ``'stack'`` takes precedence over ``'cat'`` if both are present.

        :raises Exception: If neither ``'stack'`` nor ``'cat'`` is given.
        """
        for stype in ('stack', 'cat'):
            if stype in data:
                return stype, data[stype]
        msg = 'Please specify the dimension to stack or concatenate.'
        raise Exception(msg)

    def _setup_data(self, dObj, path):
        """Point ``dObj.data`` at the dataset ``path`` in its backing file.

        ``update_value`` is applied to ``path`` exactly once, and the same
        resolved path is used for both the existence check and the lookup.

        :raises Exception: If the resolved path is not in the backing file.
        """
        path = self.update_value(dObj, path)
        if path not in dObj.backing_file:
            raise Exception("The path '%s' was not found in %s"
                            % (path, dObj.backing_file.filename))
        dObj.data = dObj.backing_file[path]
        dObj.set_shape(dObj.data.shape)
        return dObj