Source code for plugins.stats.comparison

# Copyright 2014 Diamond Light Source Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

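"""
.. module:: comparison
   :platform: Unix
   :synopsis: A plugin that compares two input datasets and prints the
              normalised RMSD between them, the normalised RMSD with the
              contrast of the first dataset flipped, and their Pearson
              correlation coefficient. The data is passed through unchanged.

.. moduleauthor:: Jacob Williamson <>
"""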

from savu.plugins.utils import register_plugin
from savu.plugins.plugin import Plugin
from savu.plugins.driver.cpu_plugin import CpuPlugin
from savu.core.iterate_plugin_group_utils import enable_iterative_loop, \
    check_if_end_plugin_in_iterate_group, setup_extra_plugin_data_padding

import numpy as np

# This decorator is required for the configurator to recognise the plugin
@register_plugin
class Comparison(Plugin, CpuPlugin):
    """Compare two input datasets frame by frame; the data passes through.

    For every pair of frames with matching shape, the plugin accumulates:

    * the residual sum of squares (RSS) between the two frames, each scaled
      to [0, 1] using its own dataset's stats metadata ("min"/"max"),
    * the same RSS after flipping the contrast of the first frame, and
    * the partial sums of the Pearson correlation coefficient, using each
      dataset's "mean" from the stats metadata.

    ``post_process`` turns these into a normalised RMSD, a flipped-contrast
    normalised RMSD and the Pearson correlation coefficient, and prints them.
    """

    def __init__(self):
        super().__init__("Comparison")

    def nInput_datasets(self):
        return 2

    def nOutput_datasets(self):
        # One extra output dataset is declared when this plugin is the last
        # plugin of an iterate group.
        if check_if_end_plugin_in_iterate_group(self.exp):
            return 3
        return 2

    def nClone_datasets(self):
        # A clone dataset is only needed at the end of an iterate group.
        if check_if_end_plugin_in_iterate_group(self.exp):
            return 1
        return 0

    @enable_iterative_loop
    def setup(self):
        """Create the output datasets and define the data access pattern.

        Each of the first two output datasets mirrors the corresponding
        input dataset, because the data is returned unchanged. The
        accumulators filled by ``process_frames`` are reset, and every
        dataset is accessed one frame at a time ('single') using the
        ``pattern`` parameter.
        """
        in_dataset, out_dataset = self.get_datasets()
        # NOTE(review): only two outputs are created here even when
        # nOutput_datasets returns 3; presumably the extra iterate-group
        # dataset is handled by enable_iterative_loop -- confirm.
        out_dataset[0].create_dataset(in_dataset[0])
        out_dataset[1].create_dataset(in_dataset[1])

        # Per-frame accumulators, combined in post_process.
        self.rss_list = []
        self.flipped_rss_list = []
        self.data_points_list = []
        self.partial_cc_top = []
        self.partial_cc_bottom = ([], [])

        in_pData, out_pData = self.get_plugin_datasets()
        pattern = self.parameters['pattern']
        for pdata in in_pData:
            pdata.plugin_data_setup(pattern, 'single')
        out_pData[0].plugin_data_setup(pattern, 'single')
        out_pData[1].plugin_data_setup(pattern, 'single')

    def pre_process(self):
        """Look up display names and stats metadata for both inputs.

        Sets ``self.names`` (each dataset's group name, falling back to
        "dataset1"/"dataset2"), ``self.stats`` (stats dictionaries) and
        ``self.ranges`` ("max" - "min"). A dataset whose stats metadata is
        missing or incomplete gets ``None`` for BOTH its stats and its range,
        which disables the comparison.
        """
        in_datasets = self.get_in_datasets()
        self.names = [in_datasets[0].group_name or "dataset1",
                      in_datasets[1].group_name or "dataset2"]
        self.stats = [None, None]
        self.ranges = [None, None]
        for i, (dataset, name) in enumerate(zip(in_datasets, self.names)):
            self.stats[i], self.ranges[i] = self._get_stats_and_range(
                dataset, name)

    def _get_stats_and_range(self, dataset, name):
        """Return ``(stats, max - min)`` for ``dataset``, or ``(None, None)``.

        Both values are set together so that a stats dictionary lacking
        "max" or "min" cannot leave a non-None ``stats`` paired with a
        ``None`` range (which previously crashed ``_scale_data``).
        """
        try:
            stats = self.stats_obj.get_stats_from_dataset(dataset)
            data_range = stats["max"] - stats["min"]
        except KeyError:
            print(f"Can't find stats metadata in {name}, cannot do comparison")
            return None, None
        return stats, data_range

    def _can_compare(self):
        """Return True when usable stats metadata was found for both inputs."""
        return self.stats[0] is not None and self.stats[1] is not None

    def process_frames(self, data):
        """Accumulate comparison statistics for one frame of each dataset.

        :param data: list holding one frame from each input dataset.
        :returns: the two frames, unchanged.
        """
        frame1, frame2 = data[0], data[1]
        if frame1.shape != frame2.shape:
            print("Arrays different sizes, can't calculate residuals.")
        elif self._can_compare():
            self._accumulate(frame1, frame2)
        return [frame1, frame2]

    def _accumulate(self, frame1, frame2):
        """Append this frame pair's RSS, flipped RSS and Pearson partial sums."""
        stats1, stats2 = self.stats
        scaled1 = self._scale_data(frame1, stats1["min"], self.ranges[0])
        scaled2 = self._scale_data(frame2, stats2["min"], self.ranges[1])
        self.rss_list.append(self.stats_obj.calc_rss(scaled1, scaled2))
        self.data_points_list.append(frame1.size)
        # With values scaled to [0, 1], 1 - x inverts the contrast.
        self.flipped_rss_list.append(
            self.stats_obj.calc_rss(1 - scaled1, scaled2))
        # Pearson partial sums use the unscaled data and the dataset means.
        dev1 = frame1 - stats1["mean"]
        dev2 = frame2 - stats2["mean"]
        self.partial_cc_top.append(np.sum(dev1 * dev2))
        self.partial_cc_bottom[0].append(np.sum(dev1 ** 2))
        self.partial_cc_bottom[1].append(np.sum(dev2 ** 2))

    def _scale_data(self, data, vol_min, vol_range, new_min=0, new_range=1):
        """Linearly map ``data`` from [vol_min, vol_min + vol_range] to
        [new_min, new_min + new_range] (by default [0, 1])."""
        return (data - vol_min) * (new_range / vol_range) + new_min

    def post_process(self):
        """Print the normalised RMSD, flipped-contrast RMSD and Pearson CC.

        Nothing is printed when either dataset lacked stats metadata. When
        no frames were accumulated (e.g. every frame pair had mismatched
        shapes), a message is printed instead of calling ``rmsd_from_rss``
        with zero data points.
        """
        if not self._can_compare():
            return
        # NOTE(review): the accumulators only cover frames processed by this
        # plugin instance; if frames are split across processes, confirm the
        # totals are combined elsewhere before these values are meaningful.
        total_data = sum(self.data_points_list)
        if total_data == 0:
            print(f"No frames were compared between {self.names[0]} and "
                  f"{self.names[1]}, can't calculate residuals.")
            return
        pair = f"{self.names[0]} and {self.names[1]}"

        rmsd = self.stats_obj.rmsd_from_rss(sum(self.rss_list), total_data)
        print(f"Normalised root mean square deviation between {pair} "
              f"is {rmsd}")

        flipped_rmsd = self.stats_obj.rmsd_from_rss(
            sum(self.flipped_rss_list), total_data)
        print(f"Normalised root mean square deviation between {pair} "
              f"is {flipped_rmsd}, when the contrast is flipped")

        pcc = np.sum(self.partial_cc_top) / np.sqrt(
            np.sum(self.partial_cc_bottom[0])
            * np.sum(self.partial_cc_bottom[1]))
        print(f"Pearson correlation coefficient between {pair} is {pcc}")