# Copyright 2014 Diamond Light Source Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
.. module:: plugin_driver
:platform: Unix
:synopsis: Base class or all driver plugins
.. moduleauthor:: Mark Basham <scientificsoftware@diamond.ac.uk>
"""
import logging
import numpy as np
from mpi4py import MPI
from savu.plugins.driver.basic_driver import BasicDriver
[docs]class PluginDriver(BasicDriver):
def __init__(self):
super(PluginDriver, self).__init__()
self._communicator = None
def _run_plugin_instances(self, transport, communicator=MPI.COMM_WORLD):
""" Runs the pre_process, process and post_process methods.
If parameter tuning is required, loop over the methods and set the
correct parameters for each run. """
self.__set_communicator(communicator)
out_data = self.get_out_datasets()
extra_dims = self.get_plugin_tools().extra_dims
repeat = np.prod(extra_dims) if extra_dims else 1
param_idx = self.__calc_param_indices(extra_dims)
out_data_dims = [len(d.get_shape()) for d in out_data]
param_dims = [list(range(d - len(extra_dims), d)) for d in out_data_dims]
if extra_dims:
init_vars = self.__get_local_dict()
for i in range(repeat):
if extra_dims:
self.__reset_local_vars(init_vars)
self.get_plugin_tools()._set_parameters_this_instance(
param_idx[i])
for j in range(len(out_data)):
out_data[j]._get_plugin_data()\
.set_fixed_dimensions(param_dims[j], param_idx[i])
super(PluginDriver, self).\
_run_plugin_instances(transport, communicator=communicator)
self._reset_process_frames_counter()
self._revert_preview(self.parameters['in_datasets'])
for j in range(len(out_data)):
out_data[j].set_shape(out_data[j].data.shape)
def __get_local_dict(self):
""" Gets the local variables of the class minus those from the Plugin
class. """
from savu.plugins.plugin import Plugin
plugin = Plugin()
copy_keys = vars(self).keys() - vars(plugin).keys()
copy_dict = {}
for key in copy_keys:
copy_dict[key] = getattr(self, key)
return copy_dict
def __reset_local_vars(self, copy_dict):
""" Resets the class variables in copy_dict. """
for key, value in copy_dict.items():
setattr(self, key, value)
def __calc_param_indices(self, dims):
indices_list = []
for i in range(len(dims)):
chunk = int(np.prod(dims[0:i]))
repeat = int(np.prod(dims[i+1:]))
idx = np.ravel(np.kron(list(range(dims[i])), np.ones((repeat, chunk))))
indices_list.append(idx.astype(int))
return np.transpose(np.array(indices_list))
def __set_communicator(self, comm):
self._communicator = comm
[docs] def get_communicator(self):
return self._communicator