Source code for plugins.alignment.projection_shift

# Copyright 2014 Diamond Light Source Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
.. module:: projection_shift
   :platform: Unix
   :synopsis: Calculate horizontal and vertical shifts in the projection\
       images over time, using template matching.

.. moduleauthor:: Nicola Wadeson <scientificsoftware@diamond.ac.uk>

"""

import logging
import numpy as np
from skimage.feature import match_template, match_descriptors, ORB
from scipy.linalg import lstsq
from skimage.transform import AffineTransform
from skimage.measure import ransac

from savu.plugins.utils import register_plugin
from savu.plugins.filters.base_filter import BaseFilter
from savu.plugins.driver.cpu_plugin import CpuPlugin


@register_plugin
class ProjectionShift(BaseFilter, CpuPlugin):
    """
    Calculate horizontal and vertical shifts in the projection images over
    time.

    Shifts between pairs of frames that are 2, 3, 4 or 5 frames apart are
    measured (by template matching or by ORB keypoints + RANSAC) and the
    per-frame shifts are then recovered from these measurements with a
    least-squares fit.

    Plugin parameters read by this class: 'method' ('template_matching' or
    'orb_ransac'), 'template' (a sequence of 'start:end' strings, one per
    image dimension, used by template matching) and 'threshold' (a pair
    [limit, replacement]; falsy to disable).
    """

    def __init__(self):
        logging.debug("initialising Sinogram Alignment")
        super(ProjectionShift, self).__init__("ProjectionShift")
        # Template image cut from the first frame of each pair (template
        # matching only); stays None for the ORB/RANSAC method.
        self.template = None
        # Slices locating the template inside a frame; filled in by
        # pre_process when the template matching method is selected.
        self.template_params = None
        # Falsy means "no thresholding"; otherwise [limit, replacement].
        self.threshold = 0

    def pre_process(self):
        """
        Select the shift estimator and build the state shared by all frames.

        :raises ValueError: if the 'method' parameter is not recognised.
        """
        method = self.parameters['method']
        if method == 'template_matching':
            self.template_params = []
            for p in self.parameters['template']:
                start, end = p.split(':')
                self.template_params.append(slice(int(start), int(end)))
            self._calculate_shift = self._template_matching_shift
        elif method == 'orb_ransac':
            self._calculate_shift = self._orb_ransac_shift
        else:
            # Fail here with a clear message rather than later with an
            # AttributeError on the missing _calculate_shift.
            raise ValueError(
                "Unknown ProjectionShift method %r: expected "
                "'template_matching' or 'orb_ransac'." % (method,))

        if self.parameters['threshold']:
            self.threshold = self.parameters['threshold']

        # Reusable index templates for pulling single frames out of the
        # 3D data block; the slice-direction entry is overwritten per call.
        self.sl = [slice(None)]*3
        self.sl2 = [slice(None)]*3
        self.slice_dir = self.get_plugin_in_datasets()[0].get_slice_dimension()
        self.A = self._calculate_frame_matrix()

    def _calculate_frame_matrix(self):
        """
        Build the least-squares design matrix.

        Each row corresponds to one frame window from _calculate_frame_list
        and has ones in the columns of every frame of the window except the
        first, so that (row . unknowns) is the sum of the per-frame shifts
        between the first and last frame of that window.

        :returns: array of shape (number of windows, max frames + 2).
        """
        n_unknowns = self.get_max_frames() + 2  # 2 padded frames
        frame_list = self._calculate_frame_list(np.arange(n_unknowns))
        A = np.zeros((len(frame_list), n_unknowns))
        for row, window in enumerate(frame_list):
            A[row, list(window[1:])] = 1
        return A

    def process_frames(self, data):
        """
        Compute the shifts for one block of frames.

        :param data: list of input arrays; only the first entry is used.
        :returns: array with one (dim 0, dim 1) shift pair per row, as
            produced by _calculate_new_shift_array.
        """
        return self._sub_pixel_shift_adjustment(data[0])

    def _initial_setup(self, data):
        """
        Unpack the input list and allocate work arrays.

        No longer used by process_frames (which discarded everything but
        the data); kept so that any external caller continues to work.

        :returns: (data[0], nFrames, zero output array with two fewer
            entries along the slice dimension, zero (nFrames, 2) array).
        """
        data = data[0]
        shape = list(data.shape)
        nFrames = data.shape[self.slice_dir] - 2
        shape[self.slice_dir] -= 2
        output = np.zeros(tuple(shape))
        shift_array = np.zeros((nFrames, 2))
        return data, nFrames, output, shift_array

    def _threshold_copy(self, image):
        """
        Return a copy of *image* in which every value greater than
        threshold[0] has been replaced by threshold[1].

        A copy is used so that the caller's data is left untouched.
        """
        clipped = np.array(image)
        clipped[clipped > self.threshold[0]] = self.threshold[1]
        return clipped

    def _get_shift(self, data, frame1, frame2):
        """
        Measure the shift between two frames of *data*.

        :param data: array holding the frames along self.slice_dir.
        :param frame1: index of the reference frame.
        :param frame2: index of the frame compared against the reference.
        :returns: whatever the selected estimator returns (a length-2
            array for both estimators defined in this class).
        """
        self.sl[self.slice_dir] = frame1
        self.sl2[self.slice_dir] = frame2
        # NumPy requires a tuple (not a list) for multidimensional indexing.
        d1 = data[tuple(self.sl)]
        d2 = data[tuple(self.sl2)]

        # The template is cut from the reference frame. Testing
        # template_params (rather than the template itself, which starts as
        # None and would be an ndarray afterwards) makes sure it is actually
        # created and avoids truth-testing an array.
        if self.template_params:
            self.template = d1[tuple(self.template_params)]

        if self.threshold:
            d1 = self._threshold_copy(d1)
            d2 = self._threshold_copy(d2)
            if self.template is not None:
                self.template = self._threshold_copy(self.template)

        return self._calculate_shift(d1, d2, self.template)

    def _orb_ransac_shift(self, im1, im2, template):
        """
        Estimate the shift from ORB keypoints matched between two images.

        :param im1: reference image.
        :param im2: image compared against the reference.
        :param template: unused; present so both estimators share one
            signature.
        :returns: the translation of the affine model fitted by RANSAC.
        """
        descriptor_extractor = ORB()
        key1, des1 = self._find_key_points(descriptor_extractor, im1)
        key2, des2 = self._find_key_points(descriptor_extractor, im2)
        matches = match_descriptors(des1, des2, cross_check=True)

        # Coordinates of the matched keypoints in each image.
        src = key1[matches[:, 0]]
        dst = key2[matches[:, 1]]

        # Robustly estimate an affine transform model with RANSAC.
        model_robust, inliers = ransac((src, dst), AffineTransform,
                                       min_samples=3, residual_threshold=1,
                                       max_trials=100)
        return model_robust.translation

    def _find_key_points(self, desc_extractor, image):
        """
        Run the descriptor extractor on *image*.

        :returns: (keypoints, descriptors) as stored on the extractor.
        """
        desc_extractor.detect_and_extract(image)
        return desc_extractor.keypoints, desc_extractor.descriptors

    def _template_matching_shift(self, im1, im2, template):
        """
        Estimate the shift from the best template match in each image.

        :returns: index of the strongest match in *im2* minus the index of
            the strongest match in *im1*, as an integer array.
        """
        peaks = []
        for im in (im1, im2):
            match = match_template(im, template)
            peaks.append(np.unravel_index(np.argmax(match), match.shape))
        peaks = np.array(peaks)
        return peaks[1] - peaks[0]

    def _sub_pixel_shift_adjustment(self, data):
        """
        Measure the shift over every frame window and solve for the
        per-frame shifts.

        :param data: array holding the frames along self.slice_dir.
        :returns: result of _calculate_new_shift_array.
        """
        frames = np.arange(data.shape[self.slice_dir])
        measured = [
            self._get_shift(data, window[0], window[-1]).astype(np.float64)
            for window in self._calculate_frame_list(frames)]
        return self._calculate_new_shift_array(np.array(measured))

    def _calculate_frame_list(self, frames):
        """
        List every run of 6, 5, 4 and 3 consecutive entries of *frames*.

        :returns: list of tuples; all runs of length 6 first, then 5, 4, 3.
        """
        windows = []
        for width in (6, 5, 4, 3):
            windows.extend(zip(*(frames[i:] for i in range(width))))
        return windows

    def _calculate_new_shift_array(self, shift):
        """
        Solve A x = shift in the least-squares sense for both shift
        components at once.

        :param shift: array with one measured shift pair per frame window.
        :returns: the solution with its first and last rows removed, one
            shift pair per row.
        """
        # Only the first two columns were ever used; lstsq accepts a 2D
        # right-hand side and solves each column independently.
        solution = lstsq(self.A, shift[:, :2])[0]
        return solution[1:-1]

    def post_process(self):
        """
        Store the shifts and their cumulative sum in the input dataset's
        metadata under 'proj_align_shift_local' and 'proj_align_shift'.
        """
        out_data = self.get_out_datasets()[0]
        self.get_in_datasets()[0].meta_data.set(
            'proj_align_shift_local', out_data.data[:, :])
        self.get_in_datasets()[0].meta_data.set(
            'proj_align_shift', np.cumsum(out_data.data[:, :], axis=0))

    def get_max_frames(self):
        """Return the number of frames processed per call."""
        # Do not change this number as 8 is currently a requirement.
        return 8

    def nOutput_datasets(self):
        """Return the number of output datasets created by the plugin."""
        return 1

    def setup(self):
        """
        Set up the output dataset that is created by the plugin: a 2-column
        array with one row per entry along the input's first slice
        direction.
        """
        in_dataset, out_dataset = self.get_datasets()
        in_pData, out_pData = self.get_plugin_datasets()
        in_pData[0].plugin_data_setup('PROJECTION', self.get_max_frames(),
                                      fixed=True)

        slice_dim = in_dataset[0].get_slice_directions()[0]
        new_shape = (in_dataset[0].get_shape()[slice_dim], 2)

        out_dataset[0].create_dataset(shape=new_shape,
                                      axis_labels=['x.pixels', 'y.pixels'],
                                      remove=True)
        out_dataset[0].add_pattern("METADATA", core_dims=(1,),
                                   slice_dims=(0,))
        out_pData[0].plugin_data_setup('METADATA', self.get_max_frames(),
                                       fixed=True)

    def set_filter_padding(self, in_data, out_data):
        """
        Request padding of the input data along its first slice direction.
        """
        pad_dim = in_data[0].get_slice_directions()[0]
        # NOTE: a '<dim>.before.1' pad direction was previously tried here
        # and left disabled in favour of '<dim>.1'.
        in_data[0].padding = {'pad_directions': [str(pad_dim) + '.1']}