Source code for plugins.loaders.random_hdf5_loader

# Copyright 2014 Diamond Light Source Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
.. module:: random_hdf5_loader
   :platform: Unix
   :synopsis: A loader that creates an hdf5 dataset of any size at runtime,\
       filled with randomly generated numbers.

.. moduleauthor:: Nicola Wadeson <scientificsoftware@diamond.ac.uk>

"""

import os
import h5py
import logging
import numpy as np

from savu.data.chunking import Chunking
from savu.plugins.utils import register_plugin
from savu.plugins.loaders.base_loader import BaseLoader
from savu.plugins.savers.utils.hdf5_utils import Hdf5Utils


@register_plugin
class RandomHdf5Loader(BaseLoader):

    def __init__(self, name='RandomHdf5Loader'):
        super(RandomHdf5Loader, self).__init__(name)
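
    # Parameters read from self.parameters (supplied by the Savu framework,
    # normally via the process list): 'dataset_name', 'axis_labels',
    # 'patterns', 'size', 'pattern', 'range', 'dtype_', 'angles' and
    # 'file_name'.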

    def setup(self):
        exp = self.exp
        data_obj = exp.create_data_object(
            'in_data', self.parameters['dataset_name'])

        data_obj.set_axis_labels(*self.parameters['axis_labels'])
        self.__convert_patterns(data_obj)
        self.__parameter_checks(data_obj)

        data_obj.backing_file = self.__get_backing_file(data_obj)
        data_obj.data = data_obj.backing_file['/']['test']
        data_obj.data.dtype  # Need to do something to .data to keep the file open!

        data_obj.set_shape(data_obj.data.shape)
        self.n_entries = data_obj.get_shape()[0]
        self._set_rotation_angles(data_obj, self._get_n_entries())
        return data_obj
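
    # The backing file is written to <out_path>/<file_name>.h5 and reused if
    # it already exists. Each MPI process fills its own share of the frames,
    # so the starting slice from __get_start_slice_list is rank-dependent.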

    def __get_backing_file(self, data_obj):
        fname = '%s/%s.h5' % \
            (self.exp.get('out_path'), self.parameters['file_name'])

        if os.path.exists(fname):
            return h5py.File(fname, 'r')

        self.hdf5 = Hdf5Utils(self.exp)

        size = tuple(self.parameters['size'])

        patterns = data_obj.get_data_patterns()
        p_name = self.parameters['pattern'] if \
            self.parameters['pattern'] is not None else \
            list(patterns.keys())[0]
        p_dict = patterns[p_name]
        p_dict['max_frames_transfer'] = 1
        nnext = {p_name: p_dict}

        pattern_idx = {'current': nnext, 'next': nnext}
        chunking = Chunking(self.exp, pattern_idx)
        chunks = chunking._calculate_chunking(size, np.int16)

        h5file = self.hdf5._open_backing_h5(fname, 'w')
        dset = h5file.create_dataset('test', size, chunks=chunks)

        # need an mpi barrier after creating the file before populating it
        self.exp._barrier()

        slice_dirs = list(nnext.values())[0]['slice_dims']
        nDims = len(dset.shape)
        total_frames = np.prod([dset.shape[i] for i in slice_dirs])
        sub_size = \
            [1 if i in slice_dirs else dset.shape[i] for i in range(nDims)]

        # calculate the first slice for this process
        idx = 0
        sl, total_frames = \
            self.__get_start_slice_list(slice_dirs, dset.shape, total_frames)

        for i in range(total_frames):
            low, high = self.parameters['range']
            dset[tuple(sl)] = np.random.randint(
                low, high=high, size=sub_size,
                dtype=self.parameters['dtype_'])
            if sl[slice_dirs[idx]].stop == dset.shape[slice_dirs[idx]]:
                idx += 1
                if idx == len(slice_dirs):
                    break
            tmp = sl[slice_dirs[idx]]
            sl[slice_dirs[idx]] = slice(tmp.start + 1, tmp.stop + 1)

        self.exp._barrier()

        try:
            h5file.close()
        except IOError:
            logging.debug("There was a problem trying to close the file in "
                          "random_hdf5_loader")

        return self.hdf5._open_backing_h5(fname, 'r')

    def __get_start_slice_list(self, slice_dirs, shape, n_frames):
        n_processes = len(self.exp.get('processes'))
        rank = self.exp.get('process')
        frames = np.array_split(np.arange(n_frames), n_processes)[rank]
        f_range = list(range(0, frames[0])) if len(frames) else []
        sl = [slice(0, 1) if i in slice_dirs else slice(None)
              for i in range(len(shape))]
        idx = 0
        for i in f_range:
            if sl[slice_dirs[idx]].stop == shape[slice_dirs[idx]]:
                idx += 1
            tmp = sl[slice_dirs[idx]]
            sl[slice_dirs[idx]] = slice(tmp.start + 1, tmp.stop + 1)
        return sl, len(frames)

    def __convert_patterns(self, data_obj):
        # Pattern strings take the form 'NAME.0c.1s.2c', where each trailing
        # token is a dimension index followed by 'c' (core) or 's' (slice).
        pattern_list = self.parameters['patterns']
        for p in pattern_list:
            p_split = p.split('.')
            name = p_split[0]
            dims = p_split[1:]
            core_dims = tuple([int(i[0]) for i in
                               [d.split('c') for d in dims] if len(i) == 2])
            slice_dims = tuple([int(i[0]) for i in
                                [d.split('s') for d in dims] if len(i) == 2])
            data_obj.add_pattern(
                name, core_dims=core_dims, slice_dims=slice_dims)

    def _set_rotation_angles(self, data_obj, n_entries):
        angles = self.parameters['angles']

        if angles is None:
            angles = np.linspace(0, 180, n_entries)
        else:
            try:
                angles = eval(angles)
            except Exception:
                raise Exception('Cannot set angles in loader.')

        n_angles = len(angles)
        data_angles = n_entries
        if data_angles != n_angles:
            raise Exception("The number of angles %s does not match the data "
                            "dimension length %s" % (n_angles, data_angles))
        data_obj.meta_data.set("rotation_angle", angles)

    def __parameter_checks(self, data_obj):
        if not self.parameters['size']:
            raise Exception(
                'Please specify the size of the dataset to create.')

    def _get_n_entries(self):
        return self.n_entries
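
For reference, a minimal standalone sketch (not part of the Savu source) of the
same fill strategy used in __get_backing_file, without Savu or MPI: create a
chunked hdf5 dataset and populate it one frame at a time, so only a single
frame is ever held in memory. The file name, shape, dtype and value range
below are illustrative assumptions, not values taken from the plugin.

    import h5py
    import numpy as np

    # Illustrative values only: shape ordered as (rotation_angle,
    # detector_y, detector_x), filled one rotation frame at a time.
    shape = (91, 128, 160)
    sub_size = (1,) + shape[1:]

    with h5py.File('random_demo.h5', 'w') as f:  # hypothetical file name
        dset = f.create_dataset('test', shape, chunks=sub_size,
                                dtype=np.int16)
        for i in range(shape[0]):
            # one frame per iteration keeps memory use at a single frame
            dset[i:i + 1] = np.random.randint(
                1, high=10, size=sub_size, dtype=np.int16)

Reopening the file with h5py.File('random_demo.h5', 'r')['test'] then gives
the analogue of the '/test' dataset that setup() binds to data_obj.data above.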