Source code for plugins.reconstructions.projectors.forward_projector_gpu

# Copyright 2014 Diamond Light Source Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
.. module:: forward_projector_gpu
   :platform: Unix
   :synopsis: A forward data projector using ToMoBAR software

.. moduleauthor:: Daniil Kazantsev <scientificsoftware@diamond.ac.uk>
"""

from savu.plugins.plugin import Plugin
from savu.plugins.driver.gpu_plugin import GpuPlugin
from savu.plugins.utils import register_plugin
from savu.core.iterate_plugin_group_utils import check_if_in_iterative_loop
from savu.core.iterate_plugin_group_utils import enable_iterative_loop, \
    check_if_end_plugin_in_iterate_group, setup_extra_plugin_data_padding

from tomobar.methodsDIR import RecToolsDIR
import numpy as np
import copy

# GPU forward projector: takes a reconstructed volume (pattern VOLUME_XZ) and
# produces sinograms with ToMoBAR's RecToolsDIR.FORWPROJ. Angles, detector
# width and centre of rotation come from the input metadata unless the
# plugin parameters 'angles_deg', 'det_horiz' and 'centre_of_rotation'
# override them.
@register_plugin
class ForwardProjectorGpu(Plugin, GpuPlugin):

    def __init__(self):
        super(ForwardProjectorGpu, self).__init__('ForwardProjectorGpu')
        self.pad = None                # padding amount along 'voxel_y'
        self.angles_total = None       # number of angles in the output
        self.det_horiz_half = None     # 0.5 * detectors_horiz
        self.detectors_horiz = None    # horizontal detector size
        # array indexed [:, 0] / [:, 1] (horizontal / vertical shift);
        # presumably one row per angle - TODO confirm against producer
        self.projection_shifts = None
        self.cor = None                # centre of rotation (scalar)
        self.angles_rad = None         # projection angles in radians
        # True once process_frames has applied its one-off sign flip to
        # self.angles_rad (reset whenever setup recomputes the angles)
        self._angles_sign_flipped = False

    @setup_extra_plugin_data_padding
    def set_filter_padding(self, in_pData, out_pData):
        """Pad input and output along the 'voxel_y' dimension.

        Padding amount comes from the 'padding' parameter; edges are
        replicated ('edge' mode).
        """
        self.pad = self.parameters['padding']
        in_data = self.get_in_datasets()[0]
        pad_slice_dir_ext = \
            in_data.get_data_dimension_by_axis_label('voxel_y')
        pad_slice_dir = '%s.%s' % (pad_slice_dir_ext, self.pad)
        pad_dict = {'pad_directions': [pad_slice_dir], 'pad_mode': 'edge'}
        in_pData[0].padding = pad_dict
        out_pData[0].padding = pad_dict

    def pre_process(self):
        """Resolve the centre of rotation used for projection.

        A user-supplied 'centre_of_rotation' parameter takes precedence.
        Otherwise the mean of the input's 'centre_of_rotation' metadata is
        used, because a single scalar CoR is required for 3D geometry.
        setup() runs before pre_process(), so the parameter must be
        re-applied here or it would be overwritten by the metadata value.
        """
        if self.parameters['centre_of_rotation'] is not None:
            self.cor = self.parameters['centre_of_rotation']
        else:
            in_meta_data = self.get_in_meta_data()[0]
            self.cor = np.mean(in_meta_data.get('centre_of_rotation'))

    @enable_iterative_loop
    def setup(self):
        """Read geometry from metadata/parameters and create the output.

        The output dataset has axes (rotation_angle, detector_y,
        detector_x) with SINOGRAM and PROJECTION patterns.
        """
        in_dataset, out_dataset = self.get_datasets()
        in_pData, out_pData = self.get_plugin_datasets()
        in_pData[0].plugin_data_setup('VOLUME_XZ', 'multiple')
        in_meta_data = self.get_in_meta_data()[0]

        # --- geometry from metadata ---
        angles_meta_deg = in_meta_data.get('rotation_angle')
        try:
            # deliberate equality test: only a value equal to True marks
            # the data as synthetic
            synthetic = self.exp.meta_data.get("synthetic") == True  # noqa: E712
        except KeyError:
            synthetic = False
        angles_meta_rad = np.deg2rad(angles_meta_deg)
        self.angles_rad = angles_meta_rad if synthetic else -angles_meta_rad
        self.detectors_horiz = in_meta_data.get('detector_x_length')

        if 'projection_shifts' in self.exp.meta_data.dict:
            self.projection_shifts = \
                self.exp.meta_data.dict['projection_shifts']

        # --- user-defined parameters override the metadata ---
        if self.parameters['angles_deg'] is not None:
            # [start, stop, num]; np.float was removed in NumPy 1.24, and
            # linspace needs an integer sample count
            angles_list = self.parameters['angles_deg']
            self.angles_rad = np.deg2rad(
                np.linspace(angles_list[0], angles_list[1],
                            int(angles_list[2]), dtype=np.float64))
        # self.angles_rad now holds freshly computed (unflipped) angles
        self._angles_sign_flipped = False
        if self.parameters['centre_of_rotation'] is not None:
            self.cor = self.parameters['centre_of_rotation']
        if self.parameters['det_horiz'] is not None:
            self.detectors_horiz = self.parameters['det_horiz']

        self.det_horiz_half = 0.5 * self.detectors_horiz
        self.angles_total = len(self.angles_rad)

        # --- output dataset ---
        out_shape_sino = self.new_shape(in_dataset[0].get_shape(),
                                        in_dataset[0])
        labels = ['rotation_angle.degrees', 'detector_y.pixel',
                  'detector_x.pixel']
        out_dataset[0].create_dataset(axis_labels=labels,
                                      shape=out_shape_sino)
        out_dataset[0].add_pattern('SINOGRAM', slice_dims=(1,),
                                   core_dims=(2, 0))
        out_dataset[0].add_pattern('PROJECTION', slice_dims=(0,),
                                   core_dims=(1, 2))
        out_pData[0].plugin_data_setup('SINOGRAM', self.get_max_frames())
        # NOTE(review): when 'angles_deg' is given, the output has
        # len(user angles) angles, yet this still copies the input's
        # rotation_angle metadata - confirm whether that is intended.
        out_dataset[0].meta_data.set('rotation_angle',
                                     copy.deepcopy(angles_meta_deg))

    def process_frames(self, data):
        """Forward-project a chunk of the volume into sinogram space.

        Non-finite voxels are zeroed first. For a 3D chunk, axis 1 is the
        vertical size; it is swapped to the front for FORWPROJ and swapped
        back on the result.
        """
        object_to_project = data[0].astype(np.float32)
        object_to_project = np.where(np.isfinite(object_to_project),
                                     object_to_project, 0)
        object_size = np.shape(object_to_project)[0]
        vert_size = None  # 2D case
        if object_to_project.ndim == 3:
            vert_size = np.shape(object_to_project)[1]

        iterate_group = check_if_in_iterative_loop(self.exp)
        if iterate_group is not None and \
                iterate_group._ip_iteration > 0 and \
                'projection_shifts' in self.exp.meta_data.dict:
            # refresh shifts that may have been updated by the loop
            self.projection_shifts = \
                self.exp.meta_data.dict['projection_shifts']
            cor = self._shifted_centre_offsets()
        else:
            self._negate_angles_once(iterate_group)
            # CoR expressed relative to the detector midpoint (with a
            # half-pixel shift), as passed to ToMoBAR's CenterRotOffset
            cor = -self.cor + self.det_horiz_half - 0.5

        rectools = RecToolsDIR(DetectorsDimH=self.detectors_horiz,
                               DetectorsDimV=vert_size,
                               CenterRotOffset=cor,
                               AnglesVec=self.angles_rad,
                               ObjSize=object_size,
                               device_projector=self.parameters['GPU_index'])
        if vert_size is None:
            return rectools.FORWPROJ(object_to_project)
        projected = rectools.FORWPROJ(
            np.require(np.swapaxes(object_to_project, 0, 1),
                       requirements='CA'))
        return np.require(np.swapaxes(projected, 0, 1), requirements='CA')

    def _negate_angles_once(self, iterate_group):
        """Apply the angle sign flip at most once per setup().

        Previously the flip ran on every process_frames call, so chunks
        alternated between +angles and -angles. Inside an iterative group
        the flip is still restricted to iteration 0, as before.
        """
        if self._angles_sign_flipped:
            return
        if iterate_group is None or iterate_group._ip_iteration == 0:
            self.angles_rad = -self.angles_rad
            self._angles_sign_flipped = True

    def _registration_enabled(self):
        """Return the 'registration' flag of a Projection2dAlignment
        plugin in the plugin list, or False if there is none."""
        for plugin_dict in self.exp.meta_data.plugin_list.plugin_list:
            if plugin_dict['name'] == 'Projection2dAlignment':
                return plugin_dict['data']['registration']
        return False

    def _shifted_centre_offsets(self):
        """Build the (rows, 2) CenterRotOffset array from projection_shifts.

        Column 0 is the horizontal CoR offset and column 1 is -0.5. When
        registration was not performed, the shifts are subtracted too.
        """
        cor = np.zeros(np.shape(self.projection_shifts))
        cor[:, 0] = -self.cor + self.det_horiz_half - 0.5
        cor[:, 1] = -0.5
        if not self._registration_enabled():
            cor[:, 0] -= self.projection_shifts[:, 0]
            cor[:, 1] -= self.projection_shifts[:, 1]
        return cor

    def new_shape(self, full_shape, data):
        """Return the sinogram shape (angles, full_shape[1], detector_x).

        `data` is unused; it is kept for interface compatibility.
        """
        return (self.angles_total, full_shape[1], self.detectors_horiz)

    def post_process(self):
        """Store the CoR as a per-slice vector in the output metadata.

        The vector is built locally so that self.cor stays a scalar for
        any later process_frames call (e.g. in an iterative loop).
        """
        in_datasets, out_datasets = self.get_datasets()
        sdirs = in_datasets[0].get_slice_dimensions()
        n_slices = int(np.prod([in_datasets[0].get_shape()[i]
                                for i in sdirs]))
        cor_vect = self.cor * np.ones(n_slices)
        out_datasets[0].meta_data.set('centre_of_rotation', cor_vect)

    def get_max_frames(self):
        return 'multiple'

    def nInput_datasets(self):
        return 1

    def nOutput_datasets(self):
        return 1