Source code for savu.data.experiment_collection

# -*- coding: utf-8 -*-
# Copyright 2014 Diamond Light Source Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
.. module:: experiment
   :platform: Unix
   :synopsis: Contains information specific to the entire experiment.

.. moduleauthor:: Nicola Wadeson <scientificsoftware@diamond.ac.uk>
"""

import os
import copy
import h5py
import logging
from mpi4py import MPI

from savu.data.meta_data import MetaData
from savu.data.plugin_list import PluginList
from savu.data.data_structures.data import Data
from savu.core.checkpointing import Checkpointing
from savu.core.iterative_plugin_runner import IteratePluginGroup
from savu.plugins.savers.utils.hdf5_utils import Hdf5Utils
from savu.core.iterate_plugin_group_utils import check_if_in_iterative_loop
import savu.plugins.loaders.utils.yaml_utils as yaml


class Experiment(object):
    """
    One instance of this class is created at the beginning of the processing
    chain and remains until the end. It holds the current data object and a
    dictionary containing all metadata.
    """

    def __init__(self, options):
        self.meta_data = MetaData(options)
        self.__set_system_params()
        self.checkpoint = Checkpointing(self)
        self.__meta_data_setup(options["process_file"])
        self.collection = {}
        self.index = {"in_data": {}, "out_data": {}}
        self.initial_datasets = None
        self.plugin = None
        self._transport = None
        self._barrier_count = 0
        self._dataset_names_complete = False
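    # Illustrative usage sketch (not part of the Savu source): the framework
    # normally builds the options dictionary from the run command before
    # instantiating this class, so a manual construction looks roughly like
    # the following. The keys shown are assumptions for illustration only;
    # a real run supplies several more.
    #
    #     options = {'data_file': '/path/to/raw_data.nxs',
    #                'process_file': '/path/to/process_list.nxs',
    #                'out_path': '/path/to/output',
    #                'transport': 'hdf5'}
    #     exp = Experiment(options)
    #     print(exp.get('out_path'))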
    def get(self, entry):
        """ Get a value from the meta data dictionary. """
        return self.meta_data.get(entry)
    def __meta_data_setup(self, process_file):
        self.meta_data.plugin_list = PluginList()
        try:
            rtype = self.meta_data.get('run_type')
            if rtype == 'test':
                self.meta_data.plugin_list.plugin_list = \
                    self.meta_data.get('plugin_list')
            else:
                raise Exception('the run_type is unknown in Experiment class')
        except KeyError:
            template = self.meta_data.get('template')
            self.meta_data.plugin_list._populate_plugin_list(
                process_file, template=template)
        self.meta_data.set("nPlugin", 0)  # initialise
        self.meta_data.set('iterate_groups', [])
    def create_data_object(self, dtype, name, override=True):
        """ Create a data object.

        Plugin developers should apply this method in loaders only.

        :param str dtype: either "in_data" or "out_data".
        """
        if name not in list(self.index[dtype].keys()) or override:
            self.index[dtype][name] = Data(name, self)
            data_obj = self.index[dtype][name]
            data_obj._set_transport_data(self.meta_data.get('transport'))
        return self.index[dtype][name]
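    # Illustrative usage sketch (not part of the Savu source): a loader plugin
    # holds a reference to this Experiment as self.exp, so its setup method
    # would typically create and describe the input dataset along these lines.
    # The dataset name and axis labels below are assumptions for illustration
    # only, and the follow-up calls vary between loaders and Savu versions.
    #
    #     data_obj = self.exp.create_data_object('in_data', 'tomo')
    #     data_obj.set_axis_labels('rotation_angle.degrees',
    #                              'detector_y.pixel', 'detector_x.pixel')
    #     # ... then register the data patterns and preview settings on
    #     # data_obj before returning from the loader's setup method.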
    def _setup(self, transport):
        self._set_nxs_file()
        self._set_process_list_path()
        self._set_transport(transport)
        self.collection = {'plugin_dict': [], 'datasets': []}
        self._setup_iterate_plugin_groups(transport)
        self._barrier()
        self._check_checkpoint()
        self._barrier()

    def _setup_iterate_plugin_groups(self, transport):
        '''
        Create all the necessary instances of IteratePluginGroup
        '''
        iterate_plugin_groups = []
        iterate_group_dicts = self.meta_data.plugin_list.iterate_plugin_groups
        for group in iterate_group_dicts:
            iterate_plugin_group = IteratePluginGroup(transport,
                                                      group['start_index'],
                                                      group['end_index'],
                                                      group['iterations'])
            iterate_plugin_groups.append(iterate_plugin_group)
        self.meta_data.set('iterate_groups', iterate_plugin_groups)

    def _finalise_setup(self, plugin_list):
        checkpoint = self.meta_data.get('checkpoint')
        self._set_dataset_names_complete()
        # save the plugin list - one process, first time only
        if self.meta_data.get('process') == \
                len(self.meta_data.get('processes')) - 1 and not checkpoint:
            # Save original process list
            plugin_list._save_plugin_list(
                self.meta_data.get('process_list_path'))
            # links the input data to the nexus file
            if self.meta_data.get("pre_run"):
                self._create_pre_run_nxs_file()
            else:
                plugin_list._save_plugin_list(
                    self.meta_data.get('nxs_filename'))
                self._add_input_data_to_nxs_file(self._get_transport())
        self._set_dataset_names_complete()
        self._save_command_log()

    def _save_command_log(self):
        """Save the original Savu run command and a modified Savu run command
        to a log file for reproducibility.
        """
        current_path = os.getcwd()
        folder = self.meta_data.get('out_path')
        log_folder = os.path.join(folder, "run_log")
        filename = os.path.join(log_folder, "run_command.txt")
        modified_command = self._get_modified_command()
        if not os.path.isfile(filename):
            # Only write savu command if savu_mpi command has not been saved
            with open(filename, 'w') as command_log:
                command_log.write(f"# The directory the command was executed from\n")
                command_log.write(f"{current_path}\n")
                command_log.write(f"# Original Savu run command\n")
                command_log.write(f"{self.meta_data.get('command')}\n")
                command_log.write(f"# A modified Savu command to use to "
                                  f"reproduce the obtained result\n")
                command_log.write(f"{modified_command}\n")

    def _get_modified_command(self):
        """Modify the input Savu run command, and replace the path to the
        process list.

        :returns: modified Savu run command string
        """
        pl_path = self.meta_data.get('process_file')
        new_pl_path = self.meta_data.get('process_list_path')
        input_command = self.meta_data.get('command')
        updated_command = input_command.replace(pl_path, new_pl_path)
        return updated_command

    def _save_pre_run_log(self):
        current_path = os.getcwd()
        folder = self.meta_data.get('out_path')
        log_folder = os.path.join(folder, "run_log")
        filename = os.path.join(log_folder, "pre_run_log.txt")
        if not os.path.isfile(filename):
            with open(filename, 'w') as pre_run_log:
                pre_run_log.write(f"# SAVU PRE-RUN\n")
                pre_run_log.write(f"# During the pre-run, the following process list was run:\n")
                pre_run_log.write(f"{self.meta_data.get('process_file_name')}\n")
                pre_run_log.write(f"# The following statistics were calculated on the input data:\n")
                if "pre_run_stats" in self.meta_data.get_dictionary().keys():
                    for key, value in self.meta_data.get("pre_run_stats").items():
                        pre_run_log.write(f" {key}: {value}\n")
                if "pre_run_preview" in self.meta_data.get_dictionary().keys():
                    pre_run_log.write(f"# The following value for the preview parameter was calculated from the input data:\n")
                    pre_run_log.write(f" {self.meta_data.get('pre_run_preview')}")
                if len(self.meta_data.get("warnings")) != 0:
                    pre_run_log.write(f"# Please read the following warnings before deciding whether to continue:\n")
                    for warning in self.meta_data.get("warnings"):
                        pre_run_log.write(f" ~ {warning}")

    def _set_process_list_path(self):
        """Create the path the process list should be saved to"""
        log_folder = os.path.join(self.meta_data.get('out_path'), "run_log")
        plname = os.path.basename(self.meta_data.get('process_file'))
        filename = os.path.join(log_folder,
                                plname if plname else "process_list.nxs")
        self.meta_data.set('process_list_path', filename)

    def _set_initial_datasets(self):
        self.initial_datasets = copy.deepcopy(self.index['in_data'])

    def _set_transport(self, transport):
        self._transport = transport

    def _get_transport(self):
        return self._transport

    def __set_system_params(self):
        sys_file = self.meta_data.get('system_params')
        import sys
        if sys_file is None:
            # look in conda environment to see which version is being used
            savu_path = sys.modules['savu'].__path__[0]
            sys_files = os.path.join(
                os.path.dirname(savu_path), 'system_files')
            subdirs = os.listdir(sys_files)
            sys_folder = 'dls' if len(subdirs) > 1 else subdirs[0]
            fname = 'system_parameters.yml'
            sys_file = os.path.join(sys_files, sys_folder, fname)
        logging.info('Using the system parameters file: %s', sys_file)
        self.meta_data.set('system_params', yaml.read_yaml(sys_file))

    def _check_checkpoint(self):
        # if checkpointing has been set but the nxs file doesn't contain an
        # entry then remove checkpointing (as the previous run didn't get far
        # enough to require it).
        if self.meta_data.get('checkpoint'):
            with h5py.File(self.meta_data.get('nxs_filename'), 'r') as f:
                if 'entry' not in f:
                    self.meta_data.set('checkpoint', None)

    def _add_input_data_to_nxs_file(self, transport):
        # save the loaded data to file
        h5 = Hdf5Utils(self)
        for name, data in self.index['in_data'].items():
            self.meta_data.set(['link_type', name], 'input_data')
            self.meta_data.set(['group_name', name], name)
            self.meta_data.set(['filename', name], data.backing_file)
            transport._populate_nexus_file(data)
            h5._link_datafile_to_nexus_file(data)

    def _create_pre_run_nxs_file(self):
        data_path = self.meta_data["data_path"]
        for name, data in self.index["in_data"].items():
            raw_data = data.backing_file
            folder = self.meta_data['out_path']
            fname = self.meta_data.get('datafile_name') + '_pre_run.nxs'
            filename = os.path.join(folder, fname)
            self.meta_data.set("pre_run_filename", filename)
            self.__copy_input_file_to_output_folder(raw_data, filename)
            if isinstance(raw_data.get(data_path, getlink=True),
                          h5py.ExternalLink):
                link = raw_data.get(data_path, getlink=True)
                location = f'{"/".join(self.meta_data.get("data_file").split("/")[:-1])}/{link.filename}'
                # new_filename = os.path.join(folder, link.filename)
                # with h5py.File(location, "r") as linked_file:
                #     self.__copy_input_file_to_output_folder(linked_file, new_filename)
                with h5py.File(filename, "r+") as new_file:
                    del new_file[data_path]
                    new_file[data_path] = h5py.ExternalLink(location, link.path)

    def __copy_input_file_to_output_folder(self, file, new_filename):
        with h5py.File(new_filename, "w") as new_file:
            for group_name in file.keys():
                file.copy(file[group_name], new_file["/"], group_name)

    def _set_dataset_names_complete(self):
        """ Missing in/out_datasets fields have been populated """
        self._dataset_names_complete = True

    def _get_dataset_names_complete(self):
        return self._dataset_names_complete

    def _reset_datasets(self):
        self.index['in_data'] = self.initial_datasets
        # clear out dataset dictionaries
        for data_dict in self.collection['datasets']:
            for data in data_dict.values():
                data.meta_data._set_dictionary({})

    def _get_collection(self):
        return self.collection

    def _set_experiment_for_current_plugin(self, count):
        datasets_list = \
            self.meta_data.plugin_list._get_datasets_list()[count:]
        exp_coll = self._get_collection()
        self.index['out_data'] = exp_coll['datasets'][count]
        if datasets_list:
            self._get_current_and_next_patterns(datasets_list)
        self.meta_data.set('nPlugin', count)

    def _get_current_and_next_patterns(self, datasets_lists):
        """ Get the current and next patterns associated with a dataset
        throughout the processing chain.
        """
""" current_datasets = datasets_lists[0] patterns_list = {} for current_data in current_datasets['out_datasets']: current_name = current_data['name'] current_pattern = current_data['pattern'] next_pattern = self.__find_next_pattern(datasets_lists[1:], current_name) patterns_list[current_name] = \ {'current': current_pattern, 'next': next_pattern} self.meta_data.set('current_and_next', patterns_list) def __find_next_pattern(self, datasets_lists, current_name): next_pattern = [] for next_data_list in datasets_lists: for next_data in next_data_list['in_datasets']: if next_data['name'] == current_name: next_pattern = next_data['pattern'] return next_pattern return next_pattern def _set_nxs_file(self): folder = self.meta_data.get('out_path') if self.meta_data.get("pre_run") == True: fname = self.meta_data.get('datafile_name') + '_pre_run.nxs' else: fname = self.meta_data.get('datafile_name') + '_processed.nxs' filename = os.path.join(folder, fname) self.meta_data.set('nxs_filename', filename) if self.meta_data.get('process') == 1: if self.meta_data.get('bllog'): log_folder_name = self.meta_data.get('bllog') with open(log_folder_name, 'a') as log_folder: log_folder.write(os.path.abspath(filename) + '\n') self._create_nxs_entry() def _create_nxs_entry(self): # what if the file already exists?! logging.debug("Testing nexus file") if self.meta_data.get('process') == len( self.meta_data.get('processes')) - 1 and not self.checkpoint: with h5py.File(self.meta_data.get('nxs_filename'), 'w') as nxs_file: entry_group = nxs_file.create_group('entry') entry_group.attrs['NX_class'] = 'NXentry' def _clear_data_objects(self): self.index["out_data"] = {} self.index["in_data"] = {} def _merge_out_data_to_in(self, plugin_dict): out_data = self.index['out_data'].copy() for key, data in out_data.items(): if data.remove is False: self.index['in_data'][key] = data self.collection['datasets'].append(out_data) self.collection['plugin_dict'].append(plugin_dict) self.index["out_data"] = {} def _finalise_experiment_for_current_plugin(self): finalise = {'remove': [], 'keep': []} # populate nexus file with out_dataset information and determine which # datasets to remove from the framework. 
        for key, data in self.index['out_data'].items():
            if data.remove is True:
                finalise['remove'].append(data)
            else:
                finalise['keep'].append(data)

        # find in datasets to replace
        finalise['replace'] = []
        if not check_if_in_iterative_loop(self):
            for out_name in list(self.index['out_data'].keys()):
                if out_name in list(self.index['in_data'].keys()):
                    finalise['replace'].append(self.index['in_data'][out_name])
        else:
            # temporary workaround to
            # https://jira.diamond.ac.uk/browse/SCI-10216: don't mark any
            # datasets as "to replace" if the given plugin is in an iterative
            # loop
            logging.debug('Not marking any datasets in a loop as '
                          '"to replace"')

        return finalise

    def _reorganise_datasets(self, finalise):
        # unreplicate replicated in_datasets
        self.__unreplicate_data()

        # delete all datasets for removal
        for data in finalise['remove']:
            del self.index["out_data"][data.data_info.get('name')]

        # Add remaining output datasets to input datasets
        for name, data in self.index['out_data'].items():
            data.get_preview().set_preview([])
            self.index["in_data"][name] = copy.deepcopy(data)

        self.index['out_data'] = {}

    def __unreplicate_data(self):
        in_data_list = self.index['in_data']
        from savu.data.data_structures.data_types.replicate import Replicate
        for in_data in list(in_data_list.values()):
            if isinstance(in_data.data, Replicate):
                in_data.data = in_data.data._reset()

    def _set_all_datasets(self, name):
        data_names = []
        for key in list(self.index["in_data"].keys()):
            if 'itr_clone' not in key:
                data_names.append(key)
        return data_names

    def _barrier(self, communicator=MPI.COMM_WORLD, msg=''):
        comm_dict = {'comm': communicator}
        if self.meta_data.get('mpi') is True:
            logging.debug("Barrier %d: %d processes expected: %s",
                          self._barrier_count, communicator.size, msg)
            comm_dict['comm'].barrier()
        self._barrier_count += 1
    def log(self, log_tag, log_level=logging.DEBUG):
        """
        Log the contents of the experiment at the specified level
        """
        logging.log(log_level, "Experimental Parameters for %s", log_tag)
        for key, value in self.index["in_data"].items():
            logging.log(log_level, "in data (%s) shape = %s", key,
                        value.get_shape())
        for key, value in self.index["out_data"].items():
            logging.log(log_level, "out data (%s) shape = %s", key,
                        value.get_shape())