# Copyright 2014 Diamond Light Source Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
.. module:: time_based_correction
:platform: Unix
:synopsis: A time-based dark and flat field correction using linear\
interpolation
.. moduleauthor:: Nicola Wadeson <scientificsoftware@diamond.ac.uk>
"""
import numpy as np
np.seterr(divide='ignore', invalid='ignore')
from savu.plugins.driver.cpu_plugin import CpuPlugin
from savu.plugins.corrections.base_correction import BaseCorrection
from savu.plugins.utils import register_plugin
[docs]@register_plugin
class TimeBasedCorrection(BaseCorrection, CpuPlugin):
def __init__(self, name="TimeBasedCorrection"):
super(TimeBasedCorrection, self).__init__(name)
[docs] def pre_process(self):
self.count = 0
inData = self.get_in_datasets()[0]
pData = self.get_plugin_in_datasets()[0]
self.mfp = inData._get_plugin_data()._get_max_frames_process()
self.proj_dim = \
inData.get_data_dimension_by_axis_label('rotation_angle')
self.slice_dir = pData.get_slice_dimension()
nDims = len(pData.get_shape())
self.sslice = [slice(None)]*nDims
self.image_key = inData.data.get_image_key()
changes = np.where(np.diff(self.image_key) != 0)[0] + 1
self.split_key = np.split(self.image_key, changes)
self.split_idx = np.split(np.arange(len(self.image_key)), changes)
self.data_key = inData.data.get_index(0)
self.dark, self.dark_idx = self.calc_average(inData.data.dark(), 2)
self.flat, self.flat_idx = self.calc_average(inData.data.flat(), 1)
inData.meta_data.set('multiple_dark', self.dark)
inData.meta_data.set('multiple_flat', self.flat)
[docs] def calc_average(self, data, key):
im_key = np.where(self.image_key == key)[0]
splits = np.where(np.diff(im_key) > 1)[0]+1
local_idx = np.split(np.arange(len(im_key)), splits)
mean_data = [np.mean(data[np.array(local_idx[i])], axis=0)
for i in range(len(local_idx))]
list_idx = list(np.where([key in i for i in self.split_key])[0])
return mean_data, list_idx
[docs] def process_frames(self, data):
proj = data[0]
frame = self.get_global_frame_index()[self.count]
flat = self.calculate_flat_field(frame, proj,
*self.find_nearest_frames(self.flat_idx, frame))
dark = self.calculate_dark_field(frame, proj,
*self.find_nearest_frames(self.dark_idx, frame))
if self.parameters['in_range']:
proj = self.in_range(proj, flat)
self.count += 1
return np.nan_to_num((proj-dark)/(flat-dark))
[docs] def in_range(self, data, flat):
data[data > flat] = flat[data > flat]
return data
[docs] def find_nearest_frames(self, idx_list, value):
""" Find the index of the two entries that 'value' lies between in \
'idx_list' and calculate the distance between each of them.
"""
global_val = self.data_key[value]
# find which list (index) global_val belongs to
list_idx = [global_val in i for i in self.split_idx].index(True)
val_list = self.split_idx[list_idx]
# find length of list
length_list = len(val_list)
# find position of global_val in list and distance from each end
pos = np.where(val_list == global_val)[0][0]
dist = [(length_list-pos)/float(length_list), pos/float(length_list)]
# find closest before and after idx_list entries
new_list = list(np.sort(np.append(idx_list, list_idx)))
new_idx = new_list.index(list_idx)
entry1 = new_idx-1 if new_idx != 0 else new_idx+1
entry2 = new_idx+1 if new_idx != len(new_list)-1 else new_idx-1
before = idx_list.index(new_list[entry1])
after = idx_list.index(new_list[entry2])
return [before, after], dist
[docs] def calculate_flat_field(self, frame, data, frames, distance):
return self.flat[frames[0]]*distance[0] + \
self.flat[frames[1]]*distance[1]
[docs] def calculate_dark_field(self, frame, data, frames, distance):
return self.dark[frames[0]]*distance[0] + \
self.dark[frames[1]]*distance[1]
[docs] def get_max_frames(self):
return 'single'