# Source code for plugins.corrections.time_based_plus_drift_correction

# Copyright 2014 Diamond Light Source Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
.. module:: time_based_plus_drift_correction
   :platform: Unix
   :synopsis: A time-based dark and flat field correction that accounts for\
       image drift (Note: A work in progress but please try).

.. moduleauthor:: Nicola Wadeson <scientificsoftware@diamond.ac.uk>

"""

import numpy as np
from skimage.feature import match_template
from scipy.ndimage.interpolation import shift as sci_shift

from savu.plugins.driver.cpu_plugin import CpuPlugin
from savu.plugins.corrections.time_based_correction import TimeBasedCorrection
from savu.plugins.utils import register_plugin


[docs]@register_plugin class TimeBasedPlusDriftCorrection(TimeBasedCorrection, CpuPlugin): def __init__(self): super(TimeBasedPlusDriftCorrection, self).__init__( "TimeBasedPlusDriftCorrection")
[docs] def pre_process(self): super(TimeBasedPlusDriftCorrection, self).pre_process() self.shift_array = np.zeros((len(self.data_key), 2)) # find shift between flat field frames self.template = self.flat[0][tuple(self.set_template_params())] self.drift = self.calculate_flat_field_drift(self.template)
[docs] def set_template_params(self): template_params = [] for p in self.parameters['template']: start, end = p.split(':') template_params.append(slice(int(start), int(end))) return template_params
[docs] def calculate_flat_field_drift(self, template): drift = [] for i in range(len(self.flat)-1): drift.append(self.calculate_shift( self.flat[i], self.flat[i+1], template)) return drift
[docs] def calculate_shift(self, im1, im2, template): index = [] for im in [im1, im2]: match = match_template(im, template) index.append(np.unravel_index(np.argmax(match), match.shape)) index = np.array(index) shift = index[1] - index[0] return shift
[docs] def calculate_flat_field(self, frame, data, frames, distance): shift = self.calculate_shift(self.flat[frames[0]], data, self.template) flat1 = sci_shift(self.flat[frames[0]], tuple(shift), cval=np.nan) flat2 = sci_shift(self.flat[frames[1]], shift-self.drift[frames[0]], cval=np.nan) flat1, flat2 = self.fill_nans(flat1, flat2) if frames[0] > 0: shift = shift + self.drift[frames[0]-1] self.shift_array[np.where(self.data_key == frame)[0]] = shift return flat1*distance[0] + flat2*distance[1]
[docs] def fill_nans(self, im1, im2): im1[np.isnan(im1)] = im2[np.isnan(im1)] im2[np.isnan(im2)] = im1[np.isnan(im2)] return im1, im2
[docs] def post_process(self): inData = self.get_in_datasets()[0] inData.meta_data.set('shift', self.shift_array)